[mlir][IR] Refactor the `getChecked` and `verifyConstructionInvariants` methods on Attributes/Types

`verifyConstructionInvariants` is intended to allow for verifying the invariants of an attribute/type on construction, and `getChecked` is intended to enable more graceful error handling aside from an assert. There are a few problems with the current implementation of these methods:
* `verifyConstructionInvariants` requires an mlir::Location for emitting errors, which is prohibitively costly in the situations that would most likely use them, e.g. the parser.
This creates an unfortunate code duplication between the verifier code and the parser code, given that the parser operates on llvm::SMLoc and it is an undesirable overhead to pre-emptively convert from that to an mlir::Location.
* `getChecked` effectively requires duplicating the definition of the `get` method, creating a quite clunky workflow due to the subtle different in its signature.

This revision aims to talk the above problems by refactoring the implementation to use a callback for error emission. Using a callback allows for deferring the costly part of error emission until it is actually necessary.

Due to the necessary signature change in each instance of these methods, this revision also takes this opportunity to cleanup the definition of these methods by:
* restructuring the signature of `getChecked` such that it can be generated from the same code block as the `get` method.
* renaming `verifyConstructionInvariants` to `verify` to match the naming scheme of the rest of the compiler.

Differential Revision: https://reviews.llvm.org/D97100
This commit is contained in:
River Riddle 2021-02-22 17:30:19 -08:00
parent 662402a8b3
commit 06e25d5645
38 changed files with 766 additions and 709 deletions

View File

@ -23,8 +23,7 @@
namespace llvm {
class raw_ostream;
class StringRef;
template <typename>
class ArrayRef;
template <typename> class ArrayRef;
class hash_code;
} // namespace llvm
@ -149,8 +148,9 @@ public:
mlir::Type getEleTy() const;
static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
mlir::Type eleTy);
static mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy);
};
/// The type of a LEN parameter name. Implementations may defer the layout of a
@ -174,8 +174,9 @@ public:
mlir::Type getEleTy() const;
static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
mlir::Type eleTy);
static mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy);
};
/// The type of a reference to an entity in memory.
@ -188,8 +189,9 @@ public:
mlir::Type getEleTy() const;
static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
mlir::Type eleTy);
static mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy);
};
/// A sequence type is a multi-dimensional array of values. The sequence type
@ -239,8 +241,8 @@ public:
static constexpr Extent getUnknownExtent() { return -1; }
static mlir::LogicalResult
verifyConstructionInvariants(mlir::Location loc, const Shape &shape,
mlir::Type eleTy, mlir::AffineMapAttr map);
verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
const Shape &shape, mlir::Type eleTy, mlir::AffineMapAttr map);
};
bool operator==(const SequenceType::Shape &, const SequenceType::Shape &);
@ -256,8 +258,9 @@ public:
static TypeDescType get(mlir::Type ofType);
mlir::Type getOfTy() const;
static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
mlir::Type ofType);
static mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type ofType);
};
// Derived types
@ -290,8 +293,9 @@ public:
detail::RecordTypeStorage const *uniqueKey() const;
static mlir::LogicalResult verifyConstructionInvariants(mlir::Location,
llvm::StringRef name);
static mlir::LogicalResult
verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
llvm::StringRef name);
};
/// Is `t` a FIR Real or MLIR Float type?
@ -318,7 +322,8 @@ public:
uint64_t getLen() const;
static mlir::LogicalResult
verifyConstructionInvariants(mlir::Location, uint64_t len, mlir::Type eleTy);
verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, uint64_t len,
mlir::Type eleTy);
static bool isValidElementType(mlir::Type t) {
return isa_real(t) || isa_integer(t);
}

View File

@ -68,7 +68,7 @@ def fir_BoxProcType : FIR_Type<"BoxProc", "boxproc"> {
}];
let genAccessors = 1;
let genVerifyInvariantsDecl = 1;
let genVerifyDecl = 1;
}
def fir_BoxType : FIR_Type<"Box", "box"> {
@ -91,7 +91,7 @@ def fir_BoxType : FIR_Type<"Box", "box"> {
}];
let genAccessors = 1;
let genVerifyInvariantsDecl = 1;
let genVerifyDecl = 1;
}
def fir_CharacterType : FIR_Type<"Character", "char"> {

View File

@ -681,8 +681,7 @@ private:
} // namespace detail
template <typename A, typename B>
bool inbounds(A v, B lb, B ub) {
template <typename A, typename B> bool inbounds(A v, B lb, B ub) {
return v >= lb && v < ub;
}
@ -759,8 +758,8 @@ RealType fir::RealType::get(mlir::MLIRContext *ctxt, KindTy kind) {
KindTy fir::RealType::getFKind() const { return getImpl()->getFKind(); }
mlir::LogicalResult
fir::BoxType::verifyConstructionInvariants(mlir::Location, mlir::Type eleTy,
mlir::AffineMapAttr map) {
fir::BoxType::verify(llvm::function_ref<mlir::InFlightDiagnostic()>,
mlir::Type eleTy, mlir::AffineMapAttr map) {
// TODO
return mlir::success();
}
@ -775,15 +774,14 @@ mlir::Type fir::ReferenceType::getEleTy() const {
return getImpl()->getElementType();
}
mlir::LogicalResult
fir::ReferenceType::verifyConstructionInvariants(mlir::Location loc,
mlir::Type eleTy) {
mlir::LogicalResult fir::ReferenceType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy) {
if (eleTy.isa<ShapeType>() || eleTy.isa<ShapeShiftType>() ||
eleTy.isa<SliceType>() || eleTy.isa<FieldType>() ||
eleTy.isa<LenType>() || eleTy.isa<ReferenceType>() ||
eleTy.isa<TypeDescType>())
return mlir::emitError(loc, "cannot build a reference to type: ")
<< eleTy << '\n';
return emitError() << "cannot build a reference to type: " << eleTy << '\n';
return mlir::success();
}
@ -807,12 +805,11 @@ static bool canBePointerOrHeapElementType(mlir::Type eleTy) {
eleTy.isa<ReferenceType>() || eleTy.isa<TypeDescType>();
}
mlir::LogicalResult
fir::PointerType::verifyConstructionInvariants(mlir::Location loc,
mlir::Type eleTy) {
mlir::LogicalResult fir::PointerType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy) {
if (canBePointerOrHeapElementType(eleTy))
return mlir::emitError(loc, "cannot build a pointer to type: ")
<< eleTy << '\n';
return emitError() << "cannot build a pointer to type: " << eleTy << '\n';
return mlir::success();
}
@ -828,11 +825,11 @@ mlir::Type fir::HeapType::getEleTy() const {
}
mlir::LogicalResult
fir::HeapType::verifyConstructionInvariants(mlir::Location loc,
mlir::Type eleTy) {
fir::HeapType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy) {
if (canBePointerOrHeapElementType(eleTy))
return mlir::emitError(loc, "cannot build a heap pointer to type: ")
<< eleTy << '\n';
return emitError() << "cannot build a heap pointer to type: " << eleTy
<< '\n';
return mlir::success();
}
@ -884,8 +881,9 @@ bool fir::SequenceType::hasConstantInterior() const {
return true;
}
mlir::LogicalResult fir::SequenceType::verifyConstructionInvariants(
mlir::Location loc, const SequenceType::Shape &shape, mlir::Type eleTy,
mlir::LogicalResult fir::SequenceType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
const SequenceType::Shape &shape, mlir::Type eleTy,
mlir::AffineMapAttr map) {
// DIMENSION attribute can only be applied to an intrinsic or record type
if (eleTy.isa<BoxType>() || eleTy.isa<BoxCharType>() ||
@ -895,8 +893,8 @@ mlir::LogicalResult fir::SequenceType::verifyConstructionInvariants(
eleTy.isa<PointerType>() || eleTy.isa<ReferenceType>() ||
eleTy.isa<TypeDescType>() || eleTy.isa<fir::VectorType>() ||
eleTy.isa<SequenceType>())
return mlir::emitError(loc, "cannot build an array of this element type: ")
<< eleTy << '\n';
return emitError() << "cannot build an array of this element type: "
<< eleTy << '\n';
return mlir::success();
}
@ -955,11 +953,11 @@ detail::RecordTypeStorage const *fir::RecordType::uniqueKey() const {
return getImpl();
}
mlir::LogicalResult
fir::RecordType::verifyConstructionInvariants(mlir::Location loc,
llvm::StringRef name) {
mlir::LogicalResult fir::RecordType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
llvm::StringRef name) {
if (name.size() == 0)
return mlir::emitError(loc, "record types must have a name");
return emitError() << "record types must have a name";
return mlir::success();
}
@ -981,16 +979,16 @@ TypeDescType fir::TypeDescType::get(mlir::Type ofType) {
mlir::Type fir::TypeDescType::getOfTy() const { return getImpl()->getOfType(); }
mlir::LogicalResult
fir::TypeDescType::verifyConstructionInvariants(mlir::Location loc,
mlir::Type eleTy) {
mlir::LogicalResult fir::TypeDescType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy) {
if (eleTy.isa<BoxType>() || eleTy.isa<BoxCharType>() ||
eleTy.isa<BoxProcType>() || eleTy.isa<ShapeType>() ||
eleTy.isa<ShapeShiftType>() || eleTy.isa<SliceType>() ||
eleTy.isa<FieldType>() || eleTy.isa<LenType>() ||
eleTy.isa<ReferenceType>() || eleTy.isa<TypeDescType>())
return mlir::emitError(loc, "cannot build a type descriptor of type: ")
<< eleTy << '\n';
return emitError() << "cannot build a type descriptor of type: " << eleTy
<< '\n';
return mlir::success();
}
@ -1006,12 +1004,11 @@ mlir::Type fir::VectorType::getEleTy() const { return getImpl()->getEleTy(); }
uint64_t fir::VectorType::getLen() const { return getImpl()->getLen(); }
mlir::LogicalResult
fir::VectorType::verifyConstructionInvariants(mlir::Location loc, uint64_t len,
mlir::Type eleTy) {
mlir::LogicalResult fir::VectorType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError, uint64_t len,
mlir::Type eleTy) {
if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy)))
return mlir::emitError(loc, "cannot build a vector of type ")
<< eleTy << '\n';
return emitError() << "cannot build a vector of type " << eleTy << '\n';
return mlir::success();
}
@ -1173,14 +1170,14 @@ mlir::Type BoxProcType::parse(mlir::MLIRContext *context,
}
mlir::LogicalResult
BoxProcType::verifyConstructionInvariants(mlir::Location loc,
mlir::Type eleTy) {
BoxProcType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::Type eleTy) {
if (eleTy.isa<mlir::FunctionType>())
return mlir::success();
if (auto refTy = eleTy.dyn_cast<ReferenceType>())
if (refTy.isa<mlir::FunctionType>())
return mlir::success();
return mlir::emitError(loc, "invalid type for boxproc") << eleTy << '\n';
return emitError() << "invalid type for boxproc" << eleTy << '\n';
}
//===----------------------------------------------------------------------===//

View File

@ -1525,11 +1525,10 @@ responsible for parsing/printing the types in `Dialect::printType` and
- If the `genAccessors` field is 1 (the default) accessor methods will be
generated on the Type class (e.g. `int getWidth() const` in the example
above).
- If the `genVerifyInvariantsDecl` field is set, a declaration for a method
`static LogicalResult verifyConstructionInvariants(Location, parameters...)`
is added to the class as well as a `getChecked(Location, parameters...)`
method which gets the result of `verifyConstructionInvariants` before
calling `get`.
- If the `genVerifyDecl` field is set, a declaration for a method `static
LogicalResult verify(emitErrorFn, parameters...)` is added to the class as
well as a `getChecked(emitErrorFn, parameters...)` method which checks the
result of `verify` before calling `get`.
- The `storageClass` field can be used to set the name of the storage class.
- The `storageNamespace` field is used to set the namespace where the storage
class should sit. Defaults to "detail".
@ -1555,9 +1554,9 @@ The following builders are generated:
// given set of parameters.
static MyType get(MLIRContext *context, int intParam);
// If `genVerifyInvariantsDecl` is set to 1, the following method is also
// generated.
static MyType getChecked(Location loc, int intParam);
// If `genVerifyDecl` is set to 1, the following method is also generated.
static MyType getChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, int intParam);
```
If these autogenerated methods are not desired, such as when they conflict with

View File

@ -161,27 +161,28 @@ public:
return Base::get(type.getContext(), param, type);
}
/// This method is used to get an instance of the 'ComplexType', defined at
/// the given location. If any of the construction invariants are invalid,
/// errors are emitted with the provided location and a null type is returned.
/// This method is used to get an instance of the 'ComplexType'. If any of the
/// construction invariants are invalid, errors are emitted with the provided
/// `emitError` function and a null type is returned.
/// Note: This method is completely optional.
static ComplexType getChecked(unsigned param, Type type, Location location) {
static ComplexType getChecked(function_ref<InFlightDiagnostic()> emitError,
unsigned param, Type type) {
// Call into a helper 'getChecked' method in 'TypeBase' to get a uniqued
// instance of this type. All parameters to the storage class are passed
// after the location.
return Base::getChecked(location, param, type);
// after the context.
return Base::getChecked(emitError, type.getContext(), param, type);
}
/// This method is used to verify the construction invariants passed into the
/// 'get' and 'getChecked' methods. Note: This method is completely optional.
static LogicalResult verifyConstructionInvariants(
Location loc, unsigned param, Type type) {
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned param, Type type) {
// Our type only allows non-zero parameters.
if (param == 0)
return emitError(loc) << "non-zero parameter passed to 'ComplexType'";
return emitError() << "non-zero parameter passed to 'ComplexType'";
// Our type also expects an integer type.
if (!type.isa<IntegerType>())
return emitError(loc) << "non integer-type passed to 'ComplexType'";
return emitError() << "non integer-type passed to 'ComplexType'";
return success();
}

View File

@ -98,8 +98,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx,
/// Same as "mlirFloatAttrDoubleGet", but if the type is not valid for a
/// construction of a FloatAttr, returns a null MlirAttribute.
MLIR_CAPI_EXPORTED MlirAttribute
mlirFloatAttrDoubleGetChecked(MlirType type, double value, MlirLocation loc);
MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc,
MlirType type,
double value);
/// Returns the value stored in the given floating point attribute, interpreting
/// the value as double.

View File

@ -170,10 +170,10 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGet(intptr_t rank,
/// Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on
/// illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(intptr_t rank,
MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
intptr_t rank,
const int64_t *shape,
MlirType elementType,
MlirLocation loc);
MlirType elementType);
//===----------------------------------------------------------------------===//
// Ranked / Unranked Tensor type.
@ -196,10 +196,9 @@ MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank,
/// Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on
/// illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(intptr_t rank,
const int64_t *shape,
MlirType elementType,
MlirLocation loc);
MLIR_CAPI_EXPORTED MlirType
mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType);
/// Creates an unranked tensor type with the given element type in the same
/// context as the element type. The type is owned by the context.
@ -208,7 +207,7 @@ MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGet(MlirType elementType);
/// Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType
/// on illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType
mlirUnrankedTensorTypeGetChecked(MlirType elementType, MlirLocation loc);
mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType);
//===----------------------------------------------------------------------===//
// Ranked / Unranked MemRef type.
@ -230,8 +229,8 @@ MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(
/// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o
/// illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked(
MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps,
MlirAffineMap const *affineMaps, unsigned memorySpace, MlirLocation loc);
MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape,
intptr_t numMaps, MlirAffineMap const *affineMaps, unsigned memorySpace);
/// Creates a MemRef type with the given rank, shape, memory space and element
/// type in the same context as the element type. The type has no affine maps,
@ -245,8 +244,8 @@ MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGet(MlirType elementType,
/// Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping
/// MlirType on illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGetChecked(
MlirType elementType, intptr_t rank, const int64_t *shape,
unsigned memorySpace, MlirLocation loc);
MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape,
unsigned memorySpace);
/// Creates an Unranked MemRef type with the given element type and in the given
/// memory space. The type is owned by the context of element type.
@ -256,7 +255,7 @@ MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
/// Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping
/// MlirType on illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked(
MlirType elementType, unsigned memorySpace, MlirLocation loc);
MlirLocation loc, MlirType elementType, unsigned memorySpace);
/// Returns the number of affine layout maps in the given MemRef type.
MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type);

View File

@ -43,9 +43,7 @@ def Async_ValueType : Async_Type<"Value", "value"> {
let parameters = (ins "Type":$valueType);
let builders = [
TypeBuilderWithInferredContext<(ins "Type":$valueType), [{
return Base::get(valueType.getContext(), valueType);
}], [{
return Base::getChecked($_loc, valueType);
return $_get(valueType.getContext(), valueType);
}]>
];
let skipDefaultBuilders = 1;

View File

@ -68,6 +68,7 @@ class LLVMArrayType : public Type::TypeBase<LLVMArrayType, Type,
public:
/// Inherit base constructors.
using Base::Base;
using Base::getChecked;
/// Checks if the given type can be used inside an array type.
static bool isValidElementType(Type type);
@ -75,8 +76,8 @@ public:
/// Gets or creates an instance of LLVM dialect array type containing
/// `numElements` of `elementType`, in the same context as `elementType`.
static LLVMArrayType get(Type elementType, unsigned numElements);
static LLVMArrayType getChecked(Location loc, Type elementType,
unsigned numElements);
static LLVMArrayType getChecked(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements);
/// Returns the element type of the array.
Type getElementType();
@ -85,9 +86,8 @@ public:
unsigned getNumElements();
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verifyConstructionInvariants(Location loc,
Type elementType,
unsigned numElements);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements);
};
//===----------------------------------------------------------------------===//
@ -103,6 +103,7 @@ class LLVMFunctionType
public:
/// Inherit base constructors.
using Base::Base;
using Base::getChecked;
/// Checks if the given type can be used an argument in a function type.
static bool isValidArgumentType(Type type);
@ -117,9 +118,9 @@ public:
/// as the `result` type.
static LLVMFunctionType get(Type result, ArrayRef<Type> arguments,
bool isVarArg = false);
static LLVMFunctionType getChecked(Location loc, Type result,
ArrayRef<Type> arguments,
bool isVarArg = false);
static LLVMFunctionType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type result,
ArrayRef<Type> arguments, bool isVarArg = false);
/// Returns the result type of the function.
Type getReturnType();
@ -135,9 +136,8 @@ public:
ArrayRef<Type> params() { return getParams(); }
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verifyConstructionInvariants(Location loc, Type result,
ArrayRef<Type> arguments,
bool);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type result, ArrayRef<Type> arguments, bool);
};
//===----------------------------------------------------------------------===//
@ -152,6 +152,7 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
public:
/// Inherit base constructors.
using Base::Base;
using Base::getChecked;
/// Checks if the given type can have a pointer type pointing to it.
static bool isValidElementType(Type type);
@ -160,8 +161,9 @@ public:
/// object of `pointee` type in the given address space. The pointer type is
/// created in the same context as `pointee`.
static LLVMPointerType get(Type pointee, unsigned addressSpace = 0);
static LLVMPointerType getChecked(Location loc, Type pointee,
unsigned addressSpace = 0);
static LLVMPointerType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type pointee,
unsigned addressSpace = 0);
/// Returns the pointed-to type.
Type getElementType();
@ -170,8 +172,8 @@ public:
unsigned getAddressSpace();
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verifyConstructionInvariants(Location loc, Type pointee,
unsigned);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type pointee, unsigned);
};
//===----------------------------------------------------------------------===//
@ -217,7 +219,9 @@ public:
/// in the context. Instead, it will just return the existing struct,
/// similarly to the rest of MLIR type ::get methods.
static LLVMStructType getIdentified(MLIRContext *context, StringRef name);
static LLVMStructType getIdentifiedChecked(Location loc, StringRef name);
static LLVMStructType
getIdentifiedChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, StringRef name);
/// Gets a new identified struct with the given body. The body _cannot_ be
/// changed later. If a struct with the given name already exists, renames
@ -231,8 +235,10 @@ public:
/// context.
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef<Type> types,
bool isPacked = false);
static LLVMStructType getLiteralChecked(Location loc, ArrayRef<Type> types,
bool isPacked = false);
static LLVMStructType
getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, ArrayRef<Type> types,
bool isPacked = false);
/// Gets or creates an intentionally-opaque identified struct. Such a struct
/// cannot have its body set. To create an opaque struct with a mutable body,
@ -241,7 +247,9 @@ public:
/// already exists in the context. Instead, it will just return the existing
/// struct, similarly to the rest of MLIR type ::get methods.
static LLVMStructType getOpaque(StringRef name, MLIRContext *context);
static LLVMStructType getOpaqueChecked(Location loc, StringRef name);
static LLVMStructType
getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, StringRef name);
/// Set the body of an identified struct. Returns failure if the body could
/// not be set, e.g. if the struct already has a body or if it was marked as
@ -270,9 +278,10 @@ public:
ArrayRef<Type> getBody();
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verifyConstructionInvariants(Location, StringRef, bool);
static LogicalResult verifyConstructionInvariants(Location loc,
ArrayRef<Type> types, bool);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
StringRef, bool);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool);
};
//===----------------------------------------------------------------------===//
@ -300,9 +309,8 @@ public:
llvm::ElementCount getElementCount();
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verifyConstructionInvariants(Location loc,
Type elementType,
unsigned numElements);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements);
};
//===----------------------------------------------------------------------===//
@ -317,12 +325,14 @@ class LLVMFixedVectorType
public:
/// Inherit base constructor.
using Base::Base;
using Base::getChecked;
/// Gets or creates a fixed vector type containing `numElements` of
/// `elementType` in the same context as `elementType`.
static LLVMFixedVectorType get(Type elementType, unsigned numElements);
static LLVMFixedVectorType getChecked(Location loc, Type elementType,
unsigned numElements);
static LLVMFixedVectorType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type elementType,
unsigned numElements);
/// Checks if the given type can be used in a vector type. This type supports
/// only a subset of LLVM dialect types that don't have a built-in
@ -336,9 +346,8 @@ public:
unsigned getNumElements();
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verifyConstructionInvariants(Location loc,
Type elementType,
unsigned numElements);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements);
};
//===----------------------------------------------------------------------===//
@ -354,12 +363,14 @@ class LLVMScalableVectorType
public:
/// Inherit base constructor.
using Base::Base;
using Base::getChecked;
/// Gets or creates a scalable vector type containing a non-zero multiple of
/// `minNumElements` of `elementType` in the same context as `elementType`.
static LLVMScalableVectorType get(Type elementType, unsigned minNumElements);
static LLVMScalableVectorType getChecked(Location loc, Type elementType,
unsigned minNumElements);
static LLVMScalableVectorType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type elementType,
unsigned minNumElements);
/// Checks if the given type can be used in a vector type.
static bool isValidElementType(Type type);
@ -373,9 +384,8 @@ public:
unsigned getMinNumElements();
/// Verifies that the type about to be constructed is well-formed.
static LogicalResult verifyConstructionInvariants(Location loc,
Type elementType,
unsigned minNumElements);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned minNumElements);
};
//===----------------------------------------------------------------------===//

View File

@ -57,10 +57,10 @@ public:
/// The maximum number of bits supported for storage types.
static constexpr unsigned MaxStorageBits = 32;
static LogicalResult
verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting.
static bool classof(Type type);
@ -199,6 +199,7 @@ class AnyQuantizedType
detail::AnyQuantizedTypeStorage> {
public:
using Base::Base;
using Base::getChecked;
/// Gets an instance of the type with all parameters specified but not
/// checked.
@ -208,15 +209,16 @@ public:
/// Gets an instance of the type with all specified parameters checked.
/// Returns a nullptr convertible type on failure.
static AnyQuantizedType getChecked(unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax, Location location);
static AnyQuantizedType
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult
verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
};
/// Represents a family of uniform, quantized types.
@ -256,6 +258,7 @@ class UniformQuantizedType
detail::UniformQuantizedTypeStorage> {
public:
using Base::Base;
using Base::getChecked;
/// Gets an instance of the type with all parameters specified but not
/// checked.
@ -267,16 +270,16 @@ public:
/// Gets an instance of the type with all specified parameters checked.
/// Returns a nullptr convertible type on failure.
static UniformQuantizedType
getChecked(unsigned flags, Type storageType, Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax,
Location location);
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult
verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);
/// Gets the scale term. The scale designates the difference between the real
/// values corresponding to consecutive quantized values differing by 1.
@ -313,6 +316,7 @@ class UniformQuantizedPerAxisType
detail::UniformQuantizedPerAxisTypeStorage> {
public:
using Base::Base;
using Base::getChecked;
/// Gets an instance of the type with all parameters specified but not
/// checked.
@ -325,18 +329,18 @@ public:
/// Gets an instance of the type with all specified parameters checked.
/// Returns a nullptr convertible type on failure.
static UniformQuantizedPerAxisType
getChecked(unsigned flags, Type storageType, Type expressedType,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax, Location location);
getChecked(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult
verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
/// Gets the quantization scales. The scales designate the difference between
/// the real values corresponding to consecutive quantized values differing
@ -381,6 +385,7 @@ class CalibratedQuantizedType
detail::CalibratedQuantizedTypeStorage> {
public:
using Base::Base;
using Base::getChecked;
/// Gets an instance of the type with all parameters specified but not
/// checked.
@ -389,13 +394,13 @@ public:
/// Gets an instance of the type with all specified parameters checked.
/// Returns a nullptr convertible type on failure.
static CalibratedQuantizedType getChecked(Type expressedType, double min,
double max, Location location);
static CalibratedQuantizedType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type expressedType,
double min, double max);
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verifyConstructionInvariants(Location loc,
Type expressedType,
double min, double max);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type expressedType, double min, double max);
double getMin() const;
double getMax() const;
};

View File

@ -69,10 +69,9 @@ public:
/// Returns `spirv::StorageClass`.
Optional<StorageClass> getStorageClass();
static LogicalResult verifyConstructionInvariants(Location loc,
IntegerAttr descriptorSet,
IntegerAttr binding,
IntegerAttr storageClass);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass);
};
/// An attribute that specifies the SPIR-V (version, capabilities, extensions)
@ -120,10 +119,9 @@ public:
/// Returns the capabilities as an integer array attribute.
ArrayAttr getCapabilitiesAttr();
static LogicalResult verifyConstructionInvariants(Location loc,
IntegerAttr version,
ArrayAttr capabilities,
ArrayAttr extensions);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions);
};
/// An attribute that specifies the target version, allowed extensions and
@ -174,10 +172,10 @@ public:
/// Returns the target resource limits.
ResourceLimitsAttr getResourceLimits() const;
static LogicalResult
verifyConstructionInvariants(Location loc, VerCapExtAttr triple,
Vendor vendorID, DeviceType deviceType,
uint32_t deviceID, DictionaryAttr limits);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
VerCapExtAttr triple, Vendor vendorID,
DeviceType deviceType, uint32_t deviceID,
DictionaryAttr limits);
};
} // namespace spirv
} // namespace mlir

View File

@ -243,10 +243,11 @@ public:
static SampledImageType get(Type imageType);
static SampledImageType getChecked(Type imageType, Location location);
static SampledImageType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
static LogicalResult verifyConstructionInvariants(Location Loc,
Type imageType);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type imageType);
Type getImageType() const;
@ -426,12 +427,11 @@ public:
static MatrixType get(Type columnType, uint32_t columnCount);
static MatrixType getChecked(Type columnType, uint32_t columnCount,
Location location);
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);
static LogicalResult verifyConstructionInvariants(Location loc,
Type columnType,
uint32_t columnCount);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);
/// Returns true if the matrix elements are vectors of float elements.
static bool isValidColumnType(Type columnType);

View File

@ -182,17 +182,20 @@ class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
detail::FloatAttributeStorage> {
public:
using Base::Base;
using Base::getChecked;
using ValueType = APFloat;
/// Return a float attribute for the specified value in the specified type.
/// These methods should only be used for simple constant values, e.g 1.0/2.0,
/// that are known-valid both as host double and the 'type' format.
static FloatAttr get(Type type, double value);
static FloatAttr getChecked(Type type, double value, Location loc);
static FloatAttr getChecked(function_ref<InFlightDiagnostic()> emitError,
Type type, double value);
/// Return a float attribute for the specified value in the specified type.
static FloatAttr get(Type type, const APFloat &value);
static FloatAttr getChecked(Type type, const APFloat &value, Location loc);
static FloatAttr getChecked(function_ref<InFlightDiagnostic()> emitError,
Type type, const APFloat &value);
APFloat getValue() const;
@ -202,10 +205,10 @@ public:
static double getValueAsDouble(APFloat val);
/// Verify the construction invariants for a double value.
static LogicalResult verifyConstructionInvariants(Location loc, Type type,
double value);
static LogicalResult verifyConstructionInvariants(Location loc, Type type,
const APFloat &value);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type type, double value);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type type, const APFloat &value);
};
//===----------------------------------------------------------------------===//
@ -234,10 +237,10 @@ public:
/// an unsigned integer.
uint64_t getUInt() const;
static LogicalResult verifyConstructionInvariants(Location loc, Type type,
int64_t value);
static LogicalResult verifyConstructionInvariants(Location loc, Type type,
const APInt &value);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type type, int64_t value);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type type, const APInt &value);
};
//===----------------------------------------------------------------------===//
@ -290,6 +293,7 @@ class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
detail::OpaqueAttributeStorage> {
public:
using Base::Base;
using Base::getChecked;
/// Get or create a new OpaqueAttr with the provided dialect and string data.
static OpaqueAttr get(MLIRContext *context, Identifier dialect,
@ -298,8 +302,9 @@ public:
/// Get or create a new OpaqueAttr with the provided dialect and string data.
/// If the given identifier is not a valid namespace for a dialect, then a
/// null attribute is returned.
static OpaqueAttr getChecked(Identifier dialect, StringRef attrData,
Type type, Location location);
static OpaqueAttr getChecked(function_ref<InFlightDiagnostic()> emitError,
Identifier dialect, StringRef attrData,
Type type);
/// Returns the dialect namespace of the opaque attribute.
Identifier getDialectNamespace() const;
@ -308,10 +313,9 @@ public:
StringRef getAttrData() const;
/// Verify the construction of an opaque attribute.
static LogicalResult verifyConstructionInvariants(Location loc,
Identifier dialect,
StringRef attrData,
Type type);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Identifier dialect, StringRef attrData,
Type type);
};
//===----------------------------------------------------------------------===//
@ -428,10 +432,8 @@ public:
//===----------------------------------------------------------------------===//
namespace detail {
template <typename T>
class ElementsAttrIterator;
template <typename T>
class ElementsAttrRange;
template <typename T> class ElementsAttrIterator;
template <typename T> class ElementsAttrRange;
} // namespace detail
/// A base attribute that represents a reference to a static shaped tensor or
@ -439,10 +441,8 @@ class ElementsAttrRange;
class ElementsAttr : public Attribute {
public:
using Attribute::Attribute;
template <typename T>
using iterator = detail::ElementsAttrIterator<T>;
template <typename T>
using iterator_range = detail::ElementsAttrRange<T>;
template <typename T> using iterator = detail::ElementsAttrIterator<T>;
template <typename T> using iterator_range = detail::ElementsAttrRange<T>;
/// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
/// with static shape.
@ -454,16 +454,14 @@ public:
/// Return the value of type 'T' at the given index, where 'T' corresponds to
/// an Attribute type.
template <typename T>
T getValue(ArrayRef<uint64_t> index) const {
template <typename T> T getValue(ArrayRef<uint64_t> index) const {
return getValue(index).template cast<T>();
}
/// Return the elements of this attribute as a value of type 'T'. Note:
/// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
/// iteration.
template <typename T>
iterator_range<T> getValues() const;
template <typename T> iterator_range<T> getValues() const;
/// Return if the given 'index' refers to a valid element in this attribute.
bool isValidIndex(ArrayRef<uint64_t> index) const;
@ -540,8 +538,7 @@ protected:
};
/// Type trait detector that checks if a given type T is a complex type.
template <typename T>
struct is_complex_t : public std::false_type {};
template <typename T> struct is_complex_t : public std::false_type {};
template <typename T>
struct is_complex_t<std::complex<T>> : public std::true_type {};
} // namespace detail
@ -556,8 +553,7 @@ public:
/// floating point type that can be used to access the underlying element
/// types of a DenseElementsAttr.
// TODO: Use std::disjunction when C++17 is supported.
template <typename T>
struct is_valid_cpp_fp_type {
template <typename T> struct is_valid_cpp_fp_type {
/// The type is a valid floating point type if it is a builtin floating
/// point type, or is a potentially user defined floating point type. The
/// latter allows for supporting users that have custom types defined for
@ -826,8 +822,7 @@ public:
Attribute getValue(ArrayRef<uint64_t> index) const {
return getValue<Attribute>(index);
}
template <typename T>
T getValue(ArrayRef<uint64_t> index) const {
template <typename T> T getValue(ArrayRef<uint64_t> index) const {
// Skip to the element corresponding to the flattened index.
return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
}
@ -1236,8 +1231,7 @@ public:
/// Return the values of this attribute in the form of the given type 'T'. 'T'
/// may be any of Attribute, APInt, APFloat, c++ integer/float types, etc.
template <typename T>
llvm::iterator_range<iterator<T>> getValues() const {
template <typename T> llvm::iterator_range<iterator<T>> getValues() const {
auto zeroValue = getZeroValue<T>();
auto valueIt = getValues().getValues<T>().begin();
const std::vector<ptrdiff_t> flatSparseIndices(getFlattenedSparseIndices());
@ -1379,28 +1373,22 @@ class ElementsAttrIterator
}
/// Utility functors used to generically implement the iterators methods.
template <typename ItT>
struct PlusAssign {
template <typename ItT> struct PlusAssign {
void operator()(ItT &it, ptrdiff_t offset) { it += offset; }
};
template <typename ItT>
struct Minus {
template <typename ItT> struct Minus {
ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
};
template <typename ItT>
struct MinusAssign {
template <typename ItT> struct MinusAssign {
void operator()(ItT &it, ptrdiff_t offset) { it -= offset; }
};
template <typename ItT>
struct Dereference {
template <typename ItT> struct Dereference {
T operator()(ItT &it) { return *it; }
};
template <typename ItT>
struct ConstructIter {
template <typename ItT> struct ConstructIter {
void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); }
};
template <typename ItT>
struct DestructIter {
template <typename ItT> struct DestructIter {
void operator()(ItT &it) { it.~ItT(); }
};

View File

@ -167,22 +167,21 @@ class VectorType
: public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
public:
using Base::Base;
using Base::getChecked;
/// Get or create a new VectorType of the provided shape and element type.
/// Assumes the arguments define a well-formed VectorType.
static VectorType get(ArrayRef<int64_t> shape, Type elementType);
/// Get or create a new VectorType of the provided shape and element type
/// declared at the given, potentially unknown, location. If the VectorType
/// defined by the arguments would be ill-formed, emit errors and return
/// nullptr-wrapping type.
static VectorType getChecked(Location location, ArrayRef<int64_t> shape,
Type elementType);
/// Get or create a new VectorType of the provided shape and element type. If
/// the VectorType defined by the arguments would be ill-formed, an error is
/// emitted to `emitError` and a null type is returned.
static VectorType getChecked(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType);
/// Verify the construction of a vector type.
static LogicalResult verifyConstructionInvariants(Location loc,
ArrayRef<int64_t> shape,
Type elementType);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType);
/// Returns true of the given type can be used as an element of a vector type.
/// In particular, vectors can consist of integer or float primitives.
@ -226,22 +225,23 @@ class RankedTensorType
detail::RankedTensorTypeStorage> {
public:
using Base::Base;
using Base::getChecked;
/// Get or create a new RankedTensorType of the provided shape and element
/// type. Assumes the arguments define a well-formed type.
static RankedTensorType get(ArrayRef<int64_t> shape, Type elementType);
/// Get or create a new RankedTensorType of the provided shape and element
/// type declared at the given, potentially unknown, location. If the
/// RankedTensorType defined by the arguments would be ill-formed, emit errors
/// and return a nullptr-wrapping type.
static RankedTensorType getChecked(Location location, ArrayRef<int64_t> shape,
Type elementType);
/// type. If the RankedTensorType defined by the arguments would be
/// ill-formed, an error is emitted to `emitError` and a null type is
/// returned.
static RankedTensorType
getChecked(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType);
/// Verify the construction of a ranked tensor type.
static LogicalResult verifyConstructionInvariants(Location loc,
ArrayRef<int64_t> shape,
Type elementType);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType);
ArrayRef<int64_t> getShape() const;
};
@ -256,20 +256,22 @@ class UnrankedTensorType
detail::UnrankedTensorTypeStorage> {
public:
using Base::Base;
using Base::getChecked;
/// Get or create a new UnrankedTensorType of the provided shape and element
/// type. Assumes the arguments define a well-formed type.
static UnrankedTensorType get(Type elementType);
/// Get or create a new UnrankedTensorType of the provided shape and element
/// type declared at the given, potentially unknown, location. If the
/// UnrankedTensorType defined by the arguments would be ill-formed, emit
/// errors and return a nullptr-wrapping type.
static UnrankedTensorType getChecked(Location location, Type elementType);
/// type. If the RankedTensorType defined by the arguments would be
/// ill-formed, an error is emitted to `emitError` and a null type is
/// returned.
static UnrankedTensorType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type elementType);
/// Verify the construction of a unranked tensor type.
static LogicalResult verifyConstructionInvariants(Location loc,
Type elementType);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType);
ArrayRef<int64_t> getShape() const { return llvm::None; }
};
@ -351,6 +353,7 @@ public:
};
using Base::Base;
using Base::getChecked;
/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space. Assumes the arguments define a
@ -361,13 +364,11 @@ public:
unsigned memorySpace = 0);
/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space declared at the given location.
/// If the location is unknown, the last argument should be an instance of
/// UnknownLoc. If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
static MemRefType getChecked(Location location, ArrayRef<int64_t> shape,
Type elementType,
/// map composition, and memory space. If the MemRefType defined by the
/// arguments would be ill-formed, an error is emitted to `emitError` and a
/// null type is returned.
static MemRefType getChecked(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace);
@ -386,11 +387,11 @@ public:
private:
/// Get or create a new MemRefType defined by the arguments. If the resulting
/// type would be ill-formed, return nullptr. If the location is provided,
/// emit detailed error messages.
/// type would be ill-formed, return nullptr.
static MemRefType getImpl(ArrayRef<int64_t> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace, Optional<Location> location);
unsigned memorySpace,
function_ref<InFlightDiagnostic()> emitError);
using Base::getImpl;
};
@ -404,22 +405,23 @@ class UnrankedMemRefType
detail::UnrankedMemRefTypeStorage> {
public:
using Base::Base;
using Base::getChecked;
/// Get or create a new UnrankedMemRefType of the provided element
/// type and memory space
static UnrankedMemRefType get(Type elementType, unsigned memorySpace);
/// Get or create a new UnrankedMemRefType of the provided element
/// type and memory space declared at the given, potentially unknown,
/// location. If the UnrankedMemRefType defined by the arguments would be
/// ill-formed, emit errors and return a nullptr-wrapping type.
static UnrankedMemRefType getChecked(Location location, Type elementType,
unsigned memorySpace);
/// type and memory space. If the UnrankedMemRefType defined by the arguments
/// would be ill-formed, an error is emitted to `emitError` and a null type is
/// returned.
static UnrankedMemRefType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type elementType,
unsigned memorySpace);
/// Verify the construction of a unranked memref type.
static LogicalResult verifyConstructionInvariants(Location loc,
Type elementType,
unsigned memorySpace);
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned memorySpace);
ArrayRef<int64_t> getShape() const { return llvm::None; }
};

View File

@ -52,13 +52,11 @@ def Builtin_Complex : Builtin_Type<"Complex"> {
let parameters = (ins "Type":$elementType);
let builders = [
TypeBuilderWithInferredContext<(ins "Type":$elementType), [{
return Base::get(elementType.getContext(), elementType);
}], [{
return Base::getChecked($_loc, elementType);
return $_get(elementType.getContext(), elementType);
}]>
];
let skipDefaultBuilders = 1;
let genVerifyInvariantsDecl = 1;
let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
@ -137,7 +135,7 @@ def Builtin_Function : Builtin_Type<"Function"> {
let parameters = (ins "ArrayRef<Type>":$inputs, "ArrayRef<Type>":$results);
let builders = [
TypeBuilder<(ins CArg<"TypeRange">:$inputs, CArg<"TypeRange">:$results), [{
return Base::get($_ctxt, inputs, results);
return $_get($_ctxt, inputs, results);
}]>
];
let skipDefaultBuilders = 1;
@ -225,7 +223,7 @@ def Builtin_Integer : Builtin_Type<"Integer"> {
// memory.
let genStorageClass = 0;
let skipDefaultBuilders = 1;
let genVerifyInvariantsDecl = 1;
let genVerifyDecl = 1;
let extraClassDeclaration = [{
/// Signedness semantics.
enum SignednessSemantics : uint32_t {
@ -295,7 +293,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
"Identifier":$dialectNamespace,
StringRefParameter<"">:$typeData
);
let genVerifyInvariantsDecl = 1;
let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
@ -334,10 +332,10 @@ def Builtin_Tuple : Builtin_Type<"Tuple"> {
let parameters = (ins "ArrayRef<Type>":$types);
let builders = [
TypeBuilder<(ins "TypeRange":$elementTypes), [{
return Base::get($_ctxt, elementTypes);
return $_get($_ctxt, elementTypes);
}]>,
TypeBuilder<(ins), [{
return Base::get($_ctxt, TypeRange());
return $_get($_ctxt, TypeRange());
}]>
];
let skipDefaultBuilders = 1;

View File

@ -2492,15 +2492,21 @@ def replaceWithValue;
//
// If an empty string is passed in for `body`, then *only* the builder
// declaration will be generated; this provides a way to define complicated
// builders entirely in C++.
// builders entirely in C++. If a `body` string is provided, the `Base::get`
// method should be invoked using `$_get`, e.g.:
//
// `checkedBody` is similar to `body`, but is the code block used when
// generating a `getChecked` method.
class TypeBuilder<dag parameters, code bodyCode = "",
code checkedBodyCode = ""> {
// ```
// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg), [{
// return $_get($_ctxt, integerArg, floatArg);
// }]>
// ```
//
// This is necessary because the `body` is also used to generate `getChecked`
// methods, which have a different underlying `Base::get*` call.
//
class TypeBuilder<dag parameters, code bodyCode = ""> {
dag dagParams = parameters;
code body = bodyCode;
code checkedBody = checkedBodyCode;
// The context parameter can be inferred from one of the other parameters and
// is not implicitly added to the parameter list.
@ -2510,10 +2516,8 @@ class TypeBuilder<dag parameters, code bodyCode = "",
// A class of TypeBuilder that is able to infer the MLIRContext parameter from
// one of the other builder parameters. Instances of this builder do not have
// `MLIRContext *` implicitly added to the parameter list.
class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "",
code checkedBodyCode = "">
class TypeBuilderWithInferredContext<dag parameters, code bodyCode = "">
: TypeBuilder<parameters, bodyCode> {
code checkedBody = checkedBodyCode;
let hasInferredContextParam = 1;
}
@ -2590,9 +2594,8 @@ class TypeDef<Dialect dialect, string name,
// Avoid generating default get/getChecked functions. Custom get methods must
// be provided.
bit skipDefaultBuilders = 0;
// Generate the verifyConstructionInvariants declaration and getChecked
// method.
bit genVerifyInvariantsDecl = 0;
// Generate the verify and getChecked methods.
bit genVerifyDecl = 0;
// Extra code to include in the class declaration.
code extraClassDeclaration = [{}];

View File

@ -17,16 +17,21 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/StorageUniquer.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/FunctionExtras.h"
namespace mlir {
class AttributeStorage;
class InFlightDiagnostic;
class Location;
class MLIRContext;
namespace detail {
/// Utility method to generate a raw default location for use when checking the
/// construction invariants of a storage object. This is defined out-of-line to
/// avoid the need to include Location.h.
const AttributeStorage *generateUnknownStorageLocation(MLIRContext *ctx);
/// Utility method to generate a callback that can be used to generate a
/// diagnostic when checking the construction invariants of a storage object.
/// This is defined out-of-line to avoid the need to include Location.h.
llvm::unique_function<InFlightDiagnostic()>
getDefaultDiagnosticEmitFn(MLIRContext *ctx);
llvm::unique_function<InFlightDiagnostic()>
getDefaultDiagnosticEmitFn(const Location &loc);
//===----------------------------------------------------------------------===//
// StorageUserTraitBase
@ -88,20 +93,30 @@ public:
template <typename... Args>
static ConcreteT get(MLIRContext *ctx, Args... args) {
// Ensure that the invariants are correct for construction.
assert(succeeded(ConcreteT::verifyConstructionInvariants(
generateUnknownStorageLocation(ctx), args...)));
assert(
succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
return UniquerT::template get<ConcreteT>(ctx, args...);
}
/// Get or create a new ConcreteT instance within the ctx, defined at
/// the given, potentially unknown, location. If the arguments provided are
/// invalid then emit errors and return a null object.
template <typename LocationT, typename... Args>
static ConcreteT getChecked(LocationT loc, Args... args) {
/// invalid, errors are emitted using the provided location and a null object
/// is returned.
template <typename... Args>
static ConcreteT getChecked(const Location &loc, Args... args) {
return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...);
}
/// Get or create a new ConcreteT instance within the ctx. If the arguments
/// provided are invalid, errors are emitted using the provided `emitError`
/// and a null object is returned.
template <typename... Args>
static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
MLIRContext *ctx, Args... args) {
// If the construction invariants fail then we return a null attribute.
if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
if (failed(ConcreteT::verify(emitErrorFn, args...)))
return ConcreteT();
return UniquerT::template get<ConcreteT>(loc.getContext(), args...);
return UniquerT::template get<ConcreteT>(ctx, args...);
}
/// Get an instance of the concrete type from a void pointer.
@ -119,8 +134,7 @@ protected:
}
/// Default implementation that just returns success.
template <typename... Args>
static LogicalResult verifyConstructionInvariants(Args... args) {
template <typename... Args> static LogicalResult verify(Args... args) {
return success();
}

View File

@ -32,8 +32,9 @@ namespace mlir {
/// Derived type classes are expected to implement several required
/// implementation hooks:
/// * Optional:
/// - static LogicalResult verifyConstructionInvariants(Location loc,
/// Args... args)
/// - static LogicalResult verify(
/// function_ref<InFlightDiagnostic()> emitError,
/// Args... args)
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
/// methods to ensure that the arguments passed in are valid to construct
/// a type instance with.
@ -92,8 +93,7 @@ public:
bool operator!() const { return impl == nullptr; }
template <typename U> bool isa() const;
template <typename First, typename Second, typename... Rest>
bool isa() const;
template <typename First, typename Second, typename... Rest> bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;

View File

@ -36,10 +36,6 @@ class TypeBuilder : public Builder {
public:
using Builder::Builder;
/// Return an optional code body used for the `getChecked` variant of this
/// builder.
Optional<StringRef> getCheckedBody() const;
/// Returns true if this builder is able to infer the MLIRContext parameter.
bool hasInferredContextParameter() const;
};
@ -106,9 +102,9 @@ public:
// generated.
bool genAccessors() const;
// Return true if we need to generate the verifyConstructionInvariants
// declaration and getChecked method.
bool genVerifyInvariantsDecl() const;
// Return true if we need to generate the verify declaration and getChecked
// method.
bool genVerifyDecl() const;
// Returns the dialects extra class declaration code.
Optional<StringRef> getExtraDecls() const;

View File

@ -1466,8 +1466,7 @@ namespace {
/// CRTP base class for Python MLIR values that subclass Value and should be
/// castable from it. The value hierarchy is one level deep and is not supposed
/// to accommodate other levels unless core MLIR changes.
template <typename DerivedTy>
class PyConcreteValue : public PyValue {
template <typename DerivedTy> class PyConcreteValue : public PyValue {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
@ -1868,7 +1867,7 @@ public:
c.def_static(
"get",
[](PyType &type, double value, DefaultingPyLocation loc) {
MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc);
MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirAttributeIsNull(attr)) {
@ -2765,8 +2764,8 @@ public:
"get",
[](std::vector<int64_t> shape, PyType &elementType,
DefaultingPyLocation loc) {
MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
elementType, loc);
MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
elementType);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@ -2797,7 +2796,7 @@ public:
[](std::vector<int64_t> shape, PyType &elementType,
DefaultingPyLocation loc) {
MlirType t = mlirRankedTensorTypeGetChecked(
shape.size(), shape.data(), elementType, loc);
loc, shape.size(), shape.data(), elementType);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@ -2828,7 +2827,7 @@ public:
c.def_static(
"get",
[](PyType &elementType, DefaultingPyLocation loc) {
MlirType t = mlirUnrankedTensorTypeGetChecked(elementType, loc);
MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@ -2869,9 +2868,9 @@ public:
for (PyAffineMap &map : layout)
maps.push_back(map);
MlirType t = mlirMemRefTypeGetChecked(elementType, shape.size(),
MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
shape.data(), maps.size(),
maps.data(), memorySpace, loc);
maps.data(), memorySpace);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {
@ -2948,7 +2947,7 @@ public:
[](PyType &elementType, unsigned memorySpace,
DefaultingPyLocation loc) {
MlirType t =
mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc);
mlirUnrankedMemRefTypeGetChecked(loc, elementType, memorySpace);
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(t)) {

View File

@ -103,9 +103,9 @@ MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
return wrap(FloatAttr::get(unwrap(type), value));
}
MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value,
MlirLocation loc) {
return wrap(FloatAttr::getChecked(unwrap(type), value, unwrap(loc)));
MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type,
double value) {
return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value));
}
double mlirFloatAttrGetValueDouble(MlirAttribute attr) {

View File

@ -169,8 +169,8 @@ MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
unwrap(elementType)));
}
MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape,
MlirType elementType, MlirLocation loc) {
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType)));
@ -197,9 +197,9 @@ MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
unwrap(elementType)));
}
MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, const int64_t *shape,
MlirType elementType,
MlirLocation loc) {
MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape,
MlirType elementType) {
return wrap(RankedTensorType::getChecked(
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType)));
@ -209,8 +209,8 @@ MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
return wrap(UnrankedTensorType::get(unwrap(elementType)));
}
MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
MlirLocation loc) {
MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
MlirType elementType) {
return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
}
@ -231,10 +231,11 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
unwrap(elementType), maps, memorySpace));
}
MlirType mlirMemRefTypeGetChecked(MlirType elementType, intptr_t rank,
const int64_t *shape, intptr_t numMaps,
MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
intptr_t rank, const int64_t *shape,
intptr_t numMaps,
MlirAffineMap const *affineMaps,
unsigned memorySpace, MlirLocation loc) {
unsigned memorySpace) {
SmallVector<AffineMap, 1> maps;
(void)unwrapList(numMaps, affineMaps, maps);
return wrap(MemRefType::getChecked(
@ -250,10 +251,10 @@ MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
unwrap(elementType), llvm::None, memorySpace));
}
MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
MlirType elementType, intptr_t rank,
const int64_t *shape,
unsigned memorySpace,
MlirLocation loc) {
unsigned memorySpace) {
return wrap(MemRefType::getChecked(
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType), llvm::None, memorySpace));
@ -280,9 +281,9 @@ MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace));
}
MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
unsigned memorySpace,
MlirLocation loc) {
MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
MlirType elementType,
unsigned memorySpace) {
return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
memorySpace));
}

View File

@ -187,7 +187,7 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
// Function type without arguments.
if (succeeded(parser.parseOptionalRParen())) {
if (succeeded(parser.parseGreater()))
return LLVMFunctionType::getChecked(loc, returnType, {},
return LLVMFunctionType::getChecked(loc, returnType, llvm::None,
/*isVarArg=*/false);
return LLVMFunctionType();
}
@ -345,7 +345,8 @@ static LLVMStructType parseStructType(DialectAsmParser &parser) {
if (knownStructNames.count(name)) {
if (failed(parser.parseGreater()))
return LLVMStructType();
return LLVMStructType::getIdentifiedChecked(loc, name);
return LLVMStructType::getIdentifiedChecked(
[loc] { return emitError(loc); }, loc.getContext(), name);
}
if (failed(parser.parseComma()))
return LLVMStructType();
@ -359,7 +360,8 @@ static LLVMStructType parseStructType(DialectAsmParser &parser) {
LLVMStructType();
if (failed(parser.parseGreater()))
return LLVMStructType();
auto type = LLVMStructType::getOpaqueChecked(loc, name);
auto type = LLVMStructType::getOpaqueChecked(
[loc] { return emitError(loc); }, loc.getContext(), name);
if (!type.isOpaque()) {
parser.emitError(kwLoc, "redeclaring defined struct as opaque");
return LLVMStructType();
@ -377,8 +379,10 @@ static LLVMStructType parseStructType(DialectAsmParser &parser) {
if (failed(parser.parseGreater()))
return LLVMStructType();
if (!isIdentified)
return LLVMStructType::getLiteralChecked(loc, {}, isPacked);
auto type = LLVMStructType::getIdentifiedChecked(loc, name);
return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); },
loc.getContext(), {}, isPacked);
auto type = LLVMStructType::getIdentifiedChecked(
[loc] { return emitError(loc); }, loc.getContext(), name);
return trySetStructBody(type, {}, isPacked, parser, kwLoc);
}
@ -402,8 +406,10 @@ static LLVMStructType parseStructType(DialectAsmParser &parser) {
// Construct the struct with body.
if (!isIdentified)
return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked);
auto type = LLVMStructType::getIdentifiedChecked(loc, name);
return LLVMStructType::getLiteralChecked(
[loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked);
auto type = LLVMStructType::getIdentifiedChecked(
[loc] { return emitError(loc); }, loc.getContext(), name);
return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
}

View File

@ -39,10 +39,12 @@ LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
return Base::get(elementType.getContext(), elementType, numElements);
}
LLVMArrayType LLVMArrayType::getChecked(Location loc, Type elementType,
unsigned numElements) {
LLVMArrayType
LLVMArrayType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements) {
assert(elementType && "expected non-null subtype");
return Base::getChecked(loc, elementType, numElements);
return Base::getChecked(emitError, elementType.getContext(), elementType,
numElements);
}
Type LLVMArrayType::getElementType() { return getImpl()->elementType; }
@ -50,10 +52,10 @@ Type LLVMArrayType::getElementType() { return getImpl()->elementType; }
unsigned LLVMArrayType::getNumElements() { return getImpl()->numElements; }
LogicalResult
LLVMArrayType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned numElements) {
LLVMArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements) {
if (!isValidElementType(elementType))
return emitError(loc, "invalid array element type: ") << elementType;
return emitError() << "invalid array element type: " << elementType;
return success();
}
@ -75,11 +77,13 @@ LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
return Base::get(result.getContext(), result, arguments, isVarArg);
}
LLVMFunctionType LLVMFunctionType::getChecked(Location loc, Type result,
ArrayRef<Type> arguments,
bool isVarArg) {
LLVMFunctionType
LLVMFunctionType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type result, ArrayRef<Type> arguments,
bool isVarArg) {
assert(result && "expected non-null result");
return Base::getChecked(loc, result, arguments, isVarArg);
return Base::getChecked(emitError, result.getContext(), result, arguments,
isVarArg);
}
Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); }
@ -99,14 +103,14 @@ ArrayRef<Type> LLVMFunctionType::getParams() {
}
LogicalResult
LLVMFunctionType::verifyConstructionInvariants(Location loc, Type result,
ArrayRef<Type> arguments, bool) {
LLVMFunctionType::verify(function_ref<InFlightDiagnostic()> emitError,
Type result, ArrayRef<Type> arguments, bool) {
if (!isValidResultType(result))
return emitError(loc, "invalid function result type: ") << result;
return emitError() << "invalid function result type: " << result;
for (Type arg : arguments)
if (!isValidArgumentType(arg))
return emitError(loc, "invalid function argument type: ") << arg;
return emitError() << "invalid function argument type: " << arg;
return success();
}
@ -125,20 +129,22 @@ LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
return Base::get(pointee.getContext(), pointee, addressSpace);
}
LLVMPointerType LLVMPointerType::getChecked(Location loc, Type pointee,
unsigned addressSpace) {
return Base::getChecked(loc, pointee, addressSpace);
LLVMPointerType
LLVMPointerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type pointee, unsigned addressSpace) {
return Base::getChecked(emitError, pointee.getContext(), pointee,
addressSpace);
}
Type LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
unsigned LLVMPointerType::getAddressSpace() { return getImpl()->addressSpace; }
LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc,
Type pointee,
unsigned) {
LogicalResult
LLVMPointerType::verify(function_ref<InFlightDiagnostic()> emitError,
Type pointee, unsigned) {
if (!isValidElementType(pointee))
return emitError(loc, "invalid pointer element type: ") << pointee;
return emitError() << "invalid pointer element type: " << pointee;
return success();
}
@ -156,9 +162,10 @@ LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
return Base::get(context, name, /*opaque=*/false);
}
LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
StringRef name) {
return Base::getChecked(loc, name, /*opaque=*/false);
LLVMStructType LLVMStructType::getIdentifiedChecked(
function_ref<InFlightDiagnostic()> emitError, MLIRContext *context,
StringRef name) {
return Base::getChecked(emitError, context, name, /*opaque=*/false);
}
LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
@ -183,18 +190,21 @@ LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
return Base::get(context, types, isPacked);
}
LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
ArrayRef<Type> types,
bool isPacked) {
return Base::getChecked(loc, types, isPacked);
LLVMStructType
LLVMStructType::getLiteralChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, ArrayRef<Type> types,
bool isPacked) {
return Base::getChecked(emitError, context, types, isPacked);
}
LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
return Base::get(context, name, /*opaque=*/true);
}
LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
return Base::getChecked(loc, name, /*opaque=*/true);
LLVMStructType
LLVMStructType::getOpaqueChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, StringRef name) {
return Base::getChecked(emitError, context, name, /*opaque=*/true);
}
LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
@ -217,17 +227,17 @@ ArrayRef<Type> LLVMStructType::getBody() {
: getImpl()->getTypeList();
}
LogicalResult LLVMStructType::verifyConstructionInvariants(Location, StringRef,
bool) {
LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
StringRef, bool) {
return success();
}
LogicalResult LLVMStructType::verifyConstructionInvariants(Location loc,
ArrayRef<Type> types,
bool) {
LogicalResult
LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Type> types, bool) {
for (Type t : types)
if (!isValidElementType(t))
return emitError(loc, "invalid LLVM structure element type: ") << t;
return emitError() << "invalid LLVM structure element type: " << t;
return success();
}
@ -238,14 +248,14 @@ LogicalResult LLVMStructType::verifyConstructionInvariants(Location loc,
/// Verifies that the type about to be constructed is well-formed.
template <typename VecTy>
static LogicalResult verifyVectorConstructionInvariants(Location loc,
Type elementType,
unsigned numElements) {
static LogicalResult
verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements) {
if (numElements == 0)
return emitError(loc, "the number of vector elements must be positive");
return emitError() << "the number of vector elements must be positive";
if (!VecTy::isValidElementType(elementType))
return emitError(loc, "invalid vector element type");
return emitError() << "invalid vector element type";
return success();
}
@ -256,11 +266,12 @@ LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType,
return Base::get(elementType.getContext(), elementType, numElements);
}
LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
Type elementType,
unsigned numElements) {
LLVMFixedVectorType
LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements) {
assert(elementType && "expected non-null subtype");
return Base::getChecked(loc, elementType, numElements);
return Base::getChecked(emitError, elementType.getContext(), elementType,
numElements);
}
Type LLVMFixedVectorType::getElementType() {
@ -275,10 +286,11 @@ bool LLVMFixedVectorType::isValidElementType(Type type) {
return type.isa<LLVMPointerType, LLVMPPCFP128Type>();
}
LogicalResult LLVMFixedVectorType::verifyConstructionInvariants(
Location loc, Type elementType, unsigned numElements) {
LogicalResult
LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements) {
return verifyVectorConstructionInvariants<LLVMFixedVectorType>(
loc, elementType, numElements);
emitError, elementType, numElements);
}
//===----------------------------------------------------------------------===//
@ -292,10 +304,11 @@ LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
}
LLVMScalableVectorType
LLVMScalableVectorType::getChecked(Location loc, Type elementType,
unsigned minNumElements) {
LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned minNumElements) {
assert(elementType && "expected non-null subtype");
return Base::getChecked(loc, elementType, minNumElements);
return Base::getChecked(emitError, elementType.getContext(), elementType,
minNumElements);
}
Type LLVMScalableVectorType::getElementType() {
@ -313,10 +326,11 @@ bool LLVMScalableVectorType::isValidElementType(Type type) {
return isCompatibleFloatingPointType(type) || type.isa<LLVMPointerType>();
}
LogicalResult LLVMScalableVectorType::verifyConstructionInvariants(
Location loc, Type elementType, unsigned numElements) {
LogicalResult
LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned numElements) {
return verifyVectorConstructionInvariants<LLVMScalableVectorType>(
loc, elementType, numElements);
emitError, elementType, numElements);
}
//===----------------------------------------------------------------------===//

View File

@ -28,20 +28,21 @@ bool QuantizedType::classof(Type type) {
return llvm::isa<QuantizationDialect>(type.getDialect());
}
LogicalResult QuantizedType::verifyConstructionInvariants(
Location loc, unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {
LogicalResult
QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {
// Verify that the storage type is integral.
// This restriction may be lifted at some point in favor of using bf16
// or f16 as exact representations on hardware where that is advantageous.
auto intStorageType = storageType.dyn_cast<IntegerType>();
if (!intStorageType)
return emitError(loc, "storage type must be integral");
return emitError() << "storage type must be integral";
unsigned integralWidth = intStorageType.getWidth();
// Verify storage width.
if (integralWidth == 0 || integralWidth > MaxStorageBits)
return emitError(loc, "illegal storage type size: ") << integralWidth;
return emitError() << "illegal storage type size: " << integralWidth;
// Verify storageTypeMin and storageTypeMax.
bool isSigned =
@ -53,8 +54,8 @@ LogicalResult QuantizedType::verifyConstructionInvariants(
if (storageTypeMax - storageTypeMin <= 0 ||
storageTypeMin < defaultIntegerMin ||
storageTypeMax > defaultIntegerMax) {
return emitError(loc, "illegal storage min and storage max: (")
<< storageTypeMin << ":" << storageTypeMax << ")";
return emitError() << "illegal storage min and storage max: ("
<< storageTypeMin << ":" << storageTypeMax << ")";
}
return success();
}
@ -208,21 +209,22 @@ AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
storageTypeMin, storageTypeMax);
}
AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
Type expressedType,
int64_t storageTypeMin,
int64_t storageTypeMax,
Location location) {
return Base::getChecked(location, flags, storageType, expressedType,
storageTypeMin, storageTypeMax);
AnyQuantizedType
AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
return Base::getChecked(emitError, storageType.getContext(), flags,
storageType, expressedType, storageTypeMin,
storageTypeMax);
}
LogicalResult AnyQuantizedType::verifyConstructionInvariants(
Location loc, unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verifyConstructionInvariants(
loc, flags, storageType, expressedType, storageTypeMin,
storageTypeMax))) {
LogicalResult
AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
storageTypeMin, storageTypeMax))) {
return failure();
}
@ -230,7 +232,7 @@ LogicalResult AnyQuantizedType::verifyConstructionInvariants(
// If this restriction is ever eliminated, the parser/printer must be
// extended.
if (expressedType && !expressedType.isa<FloatType>())
return emitError(loc, "expressed type must be floating point");
return emitError() << "expressed type must be floating point";
return success();
}
@ -244,39 +246,38 @@ UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
scale, zeroPoint, storageTypeMin, storageTypeMax);
}
UniformQuantizedType
UniformQuantizedType::getChecked(unsigned flags, Type storageType,
Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax, Location location) {
return Base::getChecked(location, flags, storageType, expressedType, scale,
zeroPoint, storageTypeMin, storageTypeMax);
UniformQuantizedType UniformQuantizedType::getChecked(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
return Base::getChecked(emitError, storageType.getContext(), flags,
storageType, expressedType, scale, zeroPoint,
storageTypeMin, storageTypeMax);
}
LogicalResult UniformQuantizedType::verifyConstructionInvariants(
Location loc, unsigned flags, Type storageType, Type expressedType,
double scale, int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax) {
if (failed(QuantizedType::verifyConstructionInvariants(
loc, flags, storageType, expressedType, storageTypeMin,
storageTypeMax))) {
LogicalResult UniformQuantizedType::verify(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
storageTypeMin, storageTypeMax))) {
return failure();
}
// Uniform quantization requires fully expressed parameters, including
// expressed type.
if (!expressedType)
return emitError(loc, "uniform quantization requires expressed type");
return emitError() << "uniform quantization requires expressed type";
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
if (!expressedType.isa<FloatType>())
return emitError(loc, "expressed type must be floating point");
return emitError() << "expressed type must be floating point";
// Verify scale.
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
return emitError(loc, "illegal scale: ") << scale;
return emitError() << "illegal scale: " << scale;
return success();
}
@ -298,46 +299,45 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
}
UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
unsigned flags, Type storageType, Type expressedType,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
Location location) {
return Base::getChecked(location, flags, storageType, expressedType, scales,
zeroPoints, quantizedDimension, storageTypeMin,
storageTypeMax);
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
return Base::getChecked(emitError, storageType.getContext(), flags,
storageType, expressedType, scales, zeroPoints,
quantizedDimension, storageTypeMin, storageTypeMax);
}
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
Location loc, unsigned flags, Type storageType, Type expressedType,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax) {
if (failed(QuantizedType::verifyConstructionInvariants(
loc, flags, storageType, expressedType, storageTypeMin,
storageTypeMax))) {
LogicalResult UniformQuantizedPerAxisType::verify(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
storageTypeMin, storageTypeMax))) {
return failure();
}
// Uniform quantization requires fully expressed parameters, including
// expressed type.
if (!expressedType)
return emitError(loc, "uniform quantization requires expressed type");
return emitError() << "uniform quantization requires expressed type";
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
if (!expressedType.isa<FloatType>())
return emitError(loc, "expressed type must be floating point");
return emitError() << "expressed type must be floating point";
// Ensure that the number of scales and zeroPoints match.
if (scales.size() != zeroPoints.size())
return emitError(loc, "illegal number of scales and zeroPoints: ")
<< scales.size() << ", " << zeroPoints.size();
return emitError() << "illegal number of scales and zeroPoints: "
<< scales.size() << ", " << zeroPoints.size();
// Verify scale.
for (double scale : scales) {
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
return emitError(loc, "illegal scale: ") << scale;
return emitError() << "illegal scale: " << scale;
}
return success();
@ -360,22 +360,23 @@ CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType,
return Base::get(expressedType.getContext(), expressedType, min, max);
}
CalibratedQuantizedType CalibratedQuantizedType::getChecked(Type expressedType,
double min,
double max,
Location location) {
return Base::getChecked(location, expressedType, min, max);
CalibratedQuantizedType CalibratedQuantizedType::getChecked(
function_ref<InFlightDiagnostic()> emitError, Type expressedType,
double min, double max) {
return Base::getChecked(emitError, expressedType.getContext(), expressedType,
min, max);
}
LogicalResult CalibratedQuantizedType::verifyConstructionInvariants(
Location loc, Type expressedType, double min, double max) {
LogicalResult
CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
Type expressedType, double min, double max) {
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
if (!expressedType.isa<FloatType>())
return emitError(loc, "expressed type must be floating point");
return emitError() << "expressed type must be floating point";
if (max <= min)
return emitError(loc, "illegal min and max: (") << min << ":" << max << ")";
return emitError() << "illegal min and max: (" << min << ":" << max << ")";
return success();
}

View File

@ -155,8 +155,9 @@ static Type parseAnyType(DialectAsmParser &parser, Location loc) {
return nullptr;
}
return AnyQuantizedType::getChecked(typeFlags, storageType, expressedType,
storageTypeMin, storageTypeMax, loc);
return AnyQuantizedType::getChecked(loc, typeFlags, storageType,
expressedType, storageTypeMin,
storageTypeMax);
}
static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
@ -279,13 +280,13 @@ static Type parseUniformType(DialectAsmParser &parser, Location loc) {
ArrayRef<double> scalesRef(scales.begin(), scales.end());
ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
return UniformQuantizedPerAxisType::getChecked(
typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
quantizedDimension, storageTypeMin, storageTypeMax, loc);
loc, typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
quantizedDimension, storageTypeMin, storageTypeMax);
}
return UniformQuantizedType::getChecked(typeFlags, storageType, expressedType,
scales.front(), zeroPoints.front(),
storageTypeMin, storageTypeMax, loc);
return UniformQuantizedType::getChecked(
loc, typeFlags, storageType, expressedType, scales.front(),
zeroPoints.front(), storageTypeMin, storageTypeMax);
}
/// Parses an CalibratedQuantizedType.
@ -313,7 +314,7 @@ static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
return nullptr;
}
return CalibratedQuantizedType::getChecked(expressedType, min, max, loc);
return CalibratedQuantizedType::getChecked(loc, expressedType, min, max);
}
/// Parse a type registered to this dialect.

View File

@ -123,17 +123,17 @@ mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
// 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
// points and dequantized to 0.0.
if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
return UniformQuantizedType::getChecked(flags, storageType, expressedType,
1.0, qmin, qmin, qmax, loc);
return UniformQuantizedType::getChecked(
loc, flags, storageType, expressedType, 1.0, qmin, qmin, qmax);
}
double scale;
int64_t nudgedZeroPoint;
getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
return UniformQuantizedType::getChecked(flags, storageType, expressedType,
scale, nudgedZeroPoint, qmin, qmax,
loc);
return UniformQuantizedType::getChecked(loc, flags, storageType,
expressedType, scale, nudgedZeroPoint,
qmin, qmax);
}
UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
@ -179,6 +179,6 @@ UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType(
unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
return UniformQuantizedPerAxisType::getChecked(
flags, storageType, expressedType, scales, zeroPoints, quantizedDimension,
qmin, qmax, loc);
loc, flags, storageType, expressedType, scales, zeroPoints,
quantizedDimension, qmin, qmax);
}

View File

@ -162,23 +162,23 @@ Optional<spirv::StorageClass> spirv::InterfaceVarABIAttr::getStorageClass() {
return llvm::None;
}
LogicalResult spirv::InterfaceVarABIAttr::verifyConstructionInvariants(
Location loc, IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass) {
LogicalResult spirv::InterfaceVarABIAttr::verify(
function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
IntegerAttr binding, IntegerAttr storageClass) {
if (!descriptorSet.getType().isSignlessInteger(32))
return emitError(loc, "expected 32-bit integer for descriptor set");
return emitError() << "expected 32-bit integer for descriptor set";
if (!binding.getType().isSignlessInteger(32))
return emitError(loc, "expected 32-bit integer for binding");
return emitError() << "expected 32-bit integer for binding";
if (storageClass) {
if (auto storageClassAttr = storageClass.cast<IntegerAttr>()) {
auto storageClassValue =
spirv::symbolizeStorageClass(storageClassAttr.getInt());
if (!storageClassValue)
return emitError(loc, "unknown storage class");
return emitError() << "unknown storage class";
} else {
return emitError(loc, "expected valid storage class");
return emitError() << "expected valid storage class";
}
}
@ -257,11 +257,12 @@ ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
return getImpl()->capabilities.cast<ArrayAttr>();
}
LogicalResult spirv::VerCapExtAttr::verifyConstructionInvariants(
Location loc, IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions) {
LogicalResult
spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions) {
if (!version.getType().isSignlessInteger(32))
return emitError(loc, "expected 32-bit integer for version");
return emitError() << "expected 32-bit integer for version";
if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
if (auto intAttr = attr.dyn_cast<IntegerAttr>())
@ -269,7 +270,7 @@ LogicalResult spirv::VerCapExtAttr::verifyConstructionInvariants(
return true;
return false;
}))
return emitError(loc, "unknown capability in capability list");
return emitError() << "unknown capability in capability list";
if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
if (auto strAttr = attr.dyn_cast<StringAttr>())
@ -277,7 +278,7 @@ LogicalResult spirv::VerCapExtAttr::verifyConstructionInvariants(
return true;
return false;
}))
return emitError(loc, "unknown extension in extension list");
return emitError() << "unknown extension in extension list";
return success();
}
@ -338,12 +339,14 @@ spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
return getImpl()->limits.cast<spirv::ResourceLimitsAttr>();
}
LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
Location loc, spirv::VerCapExtAttr /*triple*/, spirv::Vendor /*vendorID*/,
spirv::DeviceType /*deviceType*/, uint32_t /*deviceID*/,
DictionaryAttr limits) {
LogicalResult
spirv::TargetEnvAttr::verify(function_ref<InFlightDiagnostic()> emitError,
spirv::VerCapExtAttr /*triple*/,
spirv::Vendor /*vendorID*/,
spirv::DeviceType /*deviceType*/,
uint32_t /*deviceID*/, DictionaryAttr limits) {
if (!limits.isa<spirv::ResourceLimitsAttr>())
return emitError(loc, "expected spirv::ResourceLimitsAttr for limits");
return emitError() << "expected spirv::ResourceLimitsAttr for limits";
return success();
}

View File

@ -260,42 +260,33 @@ void CooperativeMatrixNVType::getCapabilities(
// ImageType
//===----------------------------------------------------------------------===//
template <typename T>
static constexpr unsigned getNumBits() {
return 0;
}
template <>
constexpr unsigned getNumBits<Dim>() {
template <typename T> static constexpr unsigned getNumBits() { return 0; }
template <> constexpr unsigned getNumBits<Dim>() {
static_assert((1 << 3) > getMaxEnumValForDim(),
"Not enough bits to encode Dim value");
return 3;
}
template <>
constexpr unsigned getNumBits<ImageDepthInfo>() {
template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
"Not enough bits to encode ImageDepthInfo value");
return 2;
}
template <>
constexpr unsigned getNumBits<ImageArrayedInfo>() {
template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
"Not enough bits to encode ImageArrayedInfo value");
return 1;
}
template <>
constexpr unsigned getNumBits<ImageSamplingInfo>() {
template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
"Not enough bits to encode ImageSamplingInfo value");
return 1;
}
template <>
constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
"Not enough bits to encode ImageSamplerUseInfo value");
return 2;
}
template <>
constexpr unsigned getNumBits<ImageFormat>() {
template <> constexpr unsigned getNumBits<ImageFormat>() {
static_assert((1 << 6) > getMaxEnumValForImageFormat(),
"Not enough bits to encode ImageFormat value");
return 6;
@ -730,17 +721,19 @@ SampledImageType SampledImageType::get(Type imageType) {
return Base::get(imageType.getContext(), imageType);
}
SampledImageType SampledImageType::getChecked(Type imageType,
Location location) {
return Base::getChecked(location, imageType);
SampledImageType
SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type imageType) {
return Base::getChecked(emitError, imageType.getContext(), imageType);
}
Type SampledImageType::getImageType() const { return getImpl()->imageType; }
LogicalResult SampledImageType::verifyConstructionInvariants(Location loc,
Type imageType) {
LogicalResult
SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
Type imageType) {
if (!imageType.isa<ImageType>())
return emitError(loc, "expected image type");
return emitError() << "expected image type";
return success();
}
@ -1095,27 +1088,27 @@ MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
return Base::get(columnType.getContext(), columnType, columnCount);
}
MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
Location location) {
return Base::getChecked(location, columnType, columnCount);
MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount) {
return Base::getChecked(emitError, columnType.getContext(), columnType,
columnCount);
}
LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
Type columnType,
uint32_t columnCount) {
LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount) {
if (columnCount < 2 || columnCount > 4)
return emitError(loc, "matrix can have 2, 3, or 4 columns only");
return emitError() << "matrix can have 2, 3, or 4 columns only";
if (!isValidColumnType(columnType))
return emitError(loc, "matrix columns must be vectors of floats");
return emitError() << "matrix columns must be vectors of floats";
/// The underlying vectors (columns) must be of size 2, 3, or 4
ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
if (columnShape.size() != 1)
return emitError(loc, "matrix columns must be 1D vectors");
return emitError() << "matrix columns must be 1D vectors";
if (columnShape[0] < 2 || columnShape[0] > 4)
return emitError(loc, "matrix columns must be of size 2, 3, or 4");
return emitError() << "matrix columns must be of size 2, 3, or 4";
return success();
}

View File

@ -211,16 +211,18 @@ FloatAttr FloatAttr::get(Type type, double value) {
return Base::get(type.getContext(), type, value);
}
FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
return Base::getChecked(loc, type, value);
FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type type, double value) {
return Base::getChecked(emitError, type.getContext(), type, value);
}
FloatAttr FloatAttr::get(Type type, const APFloat &value) {
return Base::get(type.getContext(), type, value);
}
FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
return Base::getChecked(loc, type, value);
FloatAttr FloatAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type type, const APFloat &value) {
return Base::getChecked(emitError, type.getContext(), type, value);
}
APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
@ -238,27 +240,29 @@ double FloatAttr::getValueAsDouble(APFloat value) {
}
/// Verify construction invariants.
static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
static LogicalResult
verifyFloatTypeInvariants(function_ref<InFlightDiagnostic()> emitError,
Type type) {
if (!type.isa<FloatType>())
return emitError(loc, "expected floating point type");
return emitError() << "expected floating point type";
return success();
}
LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
double value) {
return verifyFloatTypeInvariants(loc, type);
LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, double value) {
return verifyFloatTypeInvariants(emitError, type);
}
LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
const APFloat &value) {
LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, const APFloat &value) {
// Verify that the type is correct.
if (failed(verifyFloatTypeInvariants(loc, type)))
if (failed(verifyFloatTypeInvariants(emitError, type)))
return failure();
// Verify that the type semantics match that of the value.
if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
return emitError(
loc, "FloatAttr type doesn't match the type implied by its value");
return emitError()
<< "FloatAttr type doesn't match the type implied by its value";
}
return success();
}
@ -326,26 +330,28 @@ uint64_t IntegerAttr::getUInt() const {
return getValue().getZExtValue();
}
static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
static LogicalResult
verifyIntegerTypeInvariants(function_ref<InFlightDiagnostic()> emitError,
Type type) {
if (type.isa<IntegerType, IndexType>())
return success();
return emitError(loc, "expected integer or index type");
return emitError() << "expected integer or index type";
}
LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
int64_t value) {
return verifyIntegerTypeInvariants(loc, type);
LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, int64_t value) {
return verifyIntegerTypeInvariants(emitError, type);
}
LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
const APInt &value) {
if (failed(verifyIntegerTypeInvariants(loc, type)))
LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Type type, const APInt &value) {
if (failed(verifyIntegerTypeInvariants(emitError, type)))
return failure();
if (auto integerType = type.dyn_cast<IntegerType>())
if (integerType.getWidth() != value.getBitWidth())
return emitError(loc, "integer type bit width (")
<< integerType.getWidth() << ") doesn't match value bit width ("
<< value.getBitWidth() << ")";
return emitError() << "integer type bit width (" << integerType.getWidth()
<< ") doesn't match value bit width ("
<< value.getBitWidth() << ")";
return success();
}
@ -381,9 +387,11 @@ OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
return Base::get(context, dialect, attrData, type);
}
OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
Type type, Location location) {
return Base::getChecked(location, dialect, attrData, type);
OpaqueAttr OpaqueAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
Identifier dialect, StringRef attrData,
Type type) {
return Base::getChecked(emitError, dialect.getContext(), dialect, attrData,
type);
}
/// Returns the dialect namespace of the opaque attribute.
@ -395,12 +403,11 @@ Identifier OpaqueAttr::getDialectNamespace() const {
StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
/// Verify the construction of an opaque attribute.
LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
Identifier dialect,
StringRef attrData,
Type type) {
LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
Identifier dialect, StringRef attrData,
Type type) {
if (!Dialect::isValidNamespace(dialect.strref()))
return emitError(loc, "invalid dialect namespace '") << dialect << "'";
return emitError() << "invalid dialect namespace '" << dialect << "'";
return success();
}

View File

@ -32,10 +32,10 @@ using namespace mlir::detail;
//===----------------------------------------------------------------------===//
/// Verify the construction of an integer type.
LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
Type elementType) {
LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType) {
if (!elementType.isIntOrFloat())
return emitError(loc, "invalid element type for complex");
return emitError() << "invalid element type for complex";
return success();
}
@ -47,12 +47,12 @@ LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
constexpr unsigned IntegerType::kMaxWidth;
/// Verify the construction of an integer type.
LogicalResult
IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
SignednessSemantics signedness) {
LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
unsigned width,
SignednessSemantics signedness) {
if (width > IntegerType::kMaxWidth) {
return emitError(loc) << "integer bitwidth is limited to "
<< IntegerType::kMaxWidth << " bits";
return emitError() << "integer bitwidth is limited to "
<< IntegerType::kMaxWidth << " bits";
}
return success();
}
@ -183,11 +183,10 @@ FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
//===----------------------------------------------------------------------===//
/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
Identifier dialect,
StringRef typeData) {
LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
Identifier dialect, StringRef typeData) {
if (!Dialect::isValidNamespace(dialect.strref()))
return emitError(loc, "invalid dialect namespace '") << dialect << "'";
return emitError() << "invalid dialect namespace '" << dialect << "'";
return success();
}
@ -362,22 +361,22 @@ VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
return Base::get(elementType.getContext(), shape, elementType);
}
VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape,
Type elementType) {
return Base::getChecked(location, shape, elementType);
VectorType VectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType) {
return Base::getChecked(emitError, elementType.getContext(), shape,
elementType);
}
LogicalResult VectorType::verifyConstructionInvariants(Location loc,
ArrayRef<int64_t> shape,
Type elementType) {
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType) {
if (shape.empty())
return emitError(loc, "vector types must have at least one dimension");
return emitError() << "vector types must have at least one dimension";
if (!isValidElementType(elementType))
return emitError(loc, "vector elements must be int or float type");
return emitError() << "vector elements must be int or float type";
if (any_of(shape, [](int64_t i) { return i <= 0; }))
return emitError(loc, "vector types must have positive constant sizes");
return emitError() << "vector types must have positive constant sizes";
return success();
}
@ -400,12 +399,12 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
// TensorType
//===----------------------------------------------------------------------===//
// Check if "elementType" can be an element type of a tensor. Emit errors if
// location is not nullptr. Returns failure if check failed.
static LogicalResult checkTensorElementType(Location location,
Type elementType) {
// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
Type elementType) {
if (!TensorType::isValidElementType(elementType))
return emitError(location, "invalid tensor element type: ") << elementType;
return emitError() << "invalid tensor element type: " << elementType;
return success();
}
@ -428,19 +427,21 @@ RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
return Base::get(elementType.getContext(), shape, elementType);
}
RankedTensorType RankedTensorType::getChecked(Location location,
ArrayRef<int64_t> shape,
Type elementType) {
return Base::getChecked(location, shape, elementType);
RankedTensorType
RankedTensorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType) {
return Base::getChecked(emitError, elementType.getContext(), shape,
elementType);
}
LogicalResult RankedTensorType::verifyConstructionInvariants(
Location loc, ArrayRef<int64_t> shape, Type elementType) {
LogicalResult
RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType) {
for (int64_t s : shape) {
if (s < -1)
return emitError(loc, "invalid tensor dimension size");
return emitError() << "invalid tensor dimension size";
}
return checkTensorElementType(loc, elementType);
return checkTensorElementType(emitError, elementType);
}
ArrayRef<int64_t> RankedTensorType::getShape() const {
@ -455,15 +456,16 @@ UnrankedTensorType UnrankedTensorType::get(Type elementType) {
return Base::get(elementType.getContext(), elementType);
}
UnrankedTensorType UnrankedTensorType::getChecked(Location location,
Type elementType) {
return Base::getChecked(location, elementType);
UnrankedTensorType
UnrankedTensorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type elementType) {
return Base::getChecked(emitError, elementType.getContext(), elementType);
}
LogicalResult
UnrankedTensorType::verifyConstructionInvariants(Location loc,
Type elementType) {
return checkTensorElementType(loc, elementType);
UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType) {
return checkTensorElementType(emitError, elementType);
}
//===----------------------------------------------------------------------===//
@ -485,8 +487,10 @@ unsigned BaseMemRefType::getMemorySpace() const {
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace) {
auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
/*location=*/llvm::None);
auto result =
getImpl(shape, elementType, affineMapComposition, memorySpace, [=] {
return emitError(UnknownLoc::get(elementType.getContext()));
});
assert(result && "Failed to construct instance of MemRefType.");
return result;
}
@ -497,12 +501,12 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
/// UnknownLoc. If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape,
Type elementType,
MemRefType MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace) {
return getImpl(shape, elementType, affineMapComposition, memorySpace,
location);
emitError);
}
/// Get or create a new MemRefType defined by the arguments. If the resulting
@ -512,18 +516,16 @@ MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape,
MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace,
Optional<Location> location) {
function_ref<InFlightDiagnostic()> emitError) {
auto *context = elementType.getContext();
if (!BaseMemRefType::isValidElementType(elementType))
return (void)emitOptionalError(location, "invalid memref element type"),
MemRefType();
return (emitError() << "invalid memref element type", MemRefType());
for (int64_t s : shape) {
// Negative sizes are not allowed except for `-1` that means dynamic size.
if (s < -1)
return (void)emitOptionalError(location, "invalid memref size"),
MemRefType();
return (emitError() << "invalid memref size", MemRefType());
}
// Check that the structure of the composition is valid, i.e. that each
@ -533,12 +535,10 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
unsigned i = 0;
for (const auto &affineMap : affineMapComposition) {
if (affineMap.getNumDims() != dim) {
if (location)
emitError(*location)
<< "memref affine map dimension mismatch between "
<< (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
<< " and affine map" << i + 1 << ": " << dim
<< " != " << affineMap.getNumDims();
emitError() << "memref affine map dimension mismatch between "
<< (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
<< " and affine map" << i + 1 << ": " << dim
<< " != " << affineMap.getNumDims();
return nullptr;
}
@ -575,17 +575,18 @@ UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
return Base::get(elementType.getContext(), elementType, memorySpace);
}
UnrankedMemRefType UnrankedMemRefType::getChecked(Location location,
Type elementType,
unsigned memorySpace) {
return Base::getChecked(location, elementType, memorySpace);
UnrankedMemRefType
UnrankedMemRefType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned memorySpace) {
return Base::getChecked(emitError, elementType.getContext(), elementType,
memorySpace);
}
LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned memorySpace) {
UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
Type elementType, unsigned memorySpace) {
if (!BaseMemRefType::isValidElementType(elementType))
return emitError(loc, "invalid memref element type");
return emitError() << "invalid memref element type";
return success();
}

View File

@ -856,12 +856,13 @@ IntegerType IntegerType::get(MLIRContext *context, unsigned width,
return Base::get(context, width, signedness);
}
IntegerType IntegerType::getChecked(Location location, unsigned width,
SignednessSemantics signedness) {
if (auto cached =
getCachedIntegerType(width, signedness, location->getContext()))
IntegerType
IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, unsigned width,
SignednessSemantics signedness) {
if (auto cached = getCachedIntegerType(width, signedness, context))
return cached;
return Base::getChecked(location, width, signedness);
return Base::getChecked(emitError, context, width, signedness);
}
/// Get an instance of the NoneType.
@ -1005,11 +1006,14 @@ IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
// StorageUniquerSupport
//===----------------------------------------------------------------------===//
/// Utility method to generate a default location for use when checking the
/// construction invariants of a storage object. This is defined out-of-line to
/// avoid the need to include Location.h.
const AttributeStorage *
mlir::detail::generateUnknownStorageLocation(MLIRContext *ctx) {
return reinterpret_cast<const AttributeStorage *>(
ctx->getImpl().unknownLocAttr.getAsOpaquePointer());
/// Utility method to generate a callback that can be used to generate a
/// diagnostic when checking the construction invariants of a storage object.
/// This is defined out-of-line to avoid the need to include Location.h.
llvm::unique_function<InFlightDiagnostic()>
mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
return [ctx] { return emitError(UnknownLoc::get(ctx)); };
}
llvm::unique_function<InFlightDiagnostic()>
mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
return [=] { return emitError(loc); };
}

View File

@ -524,9 +524,9 @@ Attribute Parser::parseExtendedAttr(Type type) {
// Otherwise, form a new opaque attribute.
return OpaqueAttr::getChecked(
getEncodedSourceLocation(loc),
Identifier::get(dialectName, state.context), symbolData,
attrType ? attrType : NoneType::get(state.context),
getEncodedSourceLocation(loc));
attrType ? attrType : NoneType::get(state.context));
});
// Ensure that the attribute has the same type as requested.
@ -563,7 +563,7 @@ Type Parser::parseExtendedType() {
// Otherwise, form a new opaque type.
return OpaqueType::getChecked(
getEncodedSourceLocation(loc),
getEncodedSourceLocation(loc), state.context,
Identifier::get(dialectName, state.context), symbolData);
});
}

View File

@ -23,13 +23,6 @@ using namespace mlir::tblgen;
// TypeBuilder
//===----------------------------------------------------------------------===//
/// Return an optional code body used for the `getChecked` variant of this
/// builder.
Optional<StringRef> TypeBuilder::getCheckedBody() const {
Optional<StringRef> body = def->getValueAsOptionalString("checkedBody");
return body && !body->empty() ? body : llvm::None;
}
/// Returns true if this builder is able to infer the MLIRContext parameter.
bool TypeBuilder::hasInferredContextParameter() const {
return def->getValueAsBit("hasInferredContextParam");
@ -111,8 +104,8 @@ llvm::Optional<StringRef> TypeDef::getParserCode() const {
bool TypeDef::genAccessors() const {
return def->getValueAsBit("genAccessors");
}
bool TypeDef::genVerifyInvariantsDecl() const {
return def->getValueAsBit("genVerifyInvariantsDecl");
bool TypeDef::genVerifyDecl() const {
return def->getValueAsBit("genVerifyDecl");
}
llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
auto value = def->getValueAsString("extraClassDeclaration");

View File

@ -48,7 +48,7 @@ def CompoundTypeA : Test_Type<"CompoundA"> {
// An example of how one could implement a standard integer.
def IntegerType : Test_Type<"TestInteger"> {
let mnemonic = "int";
let genVerifyInvariantsDecl = 1;
let genVerifyDecl = 1;
let parameters = (
ins
"unsigned":$width,
@ -67,9 +67,7 @@ def IntegerType : Test_Type<"TestInteger"> {
let builders = [
TypeBuilder<(ins "unsigned":$width,
CArg<"SignednessSemantics", "Signless">:$signedness), [{
return Base::get($_ctxt, width, signedness);
}], [{
return Base::getChecked($_loc, width, signedness);
return $_get($_ctxt, width, signedness);
}]>
];
let skipDefaultBuilders = 1;
@ -84,7 +82,7 @@ def IntegerType : Test_Type<"TestInteger"> {
if ($_parser.parseInteger(width)) return Type();
if ($_parser.parseGreater()) return Type();
Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
return getChecked(loc, width, signedness);
return getChecked(loc, loc.getContext(), width, signedness);
}];
// Any extra code one wants in the type's class declaration.

View File

@ -112,8 +112,10 @@ static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
}
// Example type validity checker.
LogicalResult TestIntegerType::verifyConstructionInvariants(
Location loc, unsigned width, TestIntegerType::SignednessSemantics ss) {
LogicalResult
TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
unsigned width,
TestIntegerType::SignednessSemantics ss) {
if (width > 8)
return failure();
return success();

View File

@ -54,11 +54,11 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
RTLValueType:$inner
);
let genVerifyInvariantsDecl = 1;
let genVerifyDecl = 1;
// DECL-LABEL: class CompoundAType : public ::mlir::Type
// DECL: static CompoundAType getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static CompoundAType getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
// DECL: static ::mlir::Type parse(::mlir::MLIRContext *context,
// DECL-NEXT: ::mlir::DialectAsmParser &parser);
@ -95,7 +95,7 @@ def D_SingleParameterType : TestType<"SingleParameter"> {
def E_IntegerType : TestType<"Integer"> {
let mnemonic = "int";
let genVerifyInvariantsDecl = 1;
let genVerifyDecl = 1;
let parameters = (
ins
"SignednessSemantics":$signedness,

View File

@ -182,27 +182,29 @@ static const char *const typeDefParsePrint = R"(
void print(::mlir::DialectAsmPrinter &printer) const;
)";
/// The code block for the verifyConstructionInvariants and getChecked.
/// The code block for the verify method declaration.
///
/// {0}: The name of the typeDef class.
/// {1}: List of parameters, parameters style.
/// {0}: List of parameters, parameters style.
static const char *const typeDefDeclVerifyStr = R"(
static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{1});
using Base::getChecked;
static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0});
)";
/// Emit the builders for the given type.
static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
TypeParamCommaFormatter &paramTypes) {
StringRef typeClass = typeDef.getCppClassName();
bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
bool genCheckedMethods = typeDef.genVerifyDecl();
if (!typeDef.skipDefaultBuilders()) {
os << llvm::formatv(
" static {0} get(::mlir::MLIRContext *context{1});\n", typeClass,
paramTypes);
if (genCheckedMethods) {
os << llvm::formatv(
" static {0} getChecked(::mlir::Location loc{1});\n", typeClass,
paramTypes);
os << llvm::formatv(" static {0} "
"getChecked(llvm::function_ref<::mlir::"
"InFlightDiagnostic()> emitError, "
"::mlir::MLIRContext *context{1});\n",
typeClass, paramTypes);
}
}
@ -231,10 +233,14 @@ static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os,
// Generate the `getChecked` variant of the builder.
if (genCheckedMethods) {
os << " static " << typeClass << " getChecked(::mlir::Location loc";
os << " static " << typeClass
<< " getChecked(llvm::function_ref<mlir::InFlightDiagnostic()> "
"emitError";
if (!builder.hasInferredContextParameter())
os << ", ::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", " << paramStr;
os << ");\n";
os << ", ";
os << paramStr << ");\n";
}
}
}
@ -265,9 +271,8 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma);
// Emit the verify invariants declaration.
if (typeDef.genVerifyInvariantsDecl())
os << llvm::formatv(typeDefDeclVerifyStr, typeDef.getCppClassName(),
emitTypeNamePairsAfterComma);
if (typeDef.genVerifyDecl())
os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
}
// Emit the mnenomic, if specified.
@ -515,10 +520,18 @@ void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
}
}
/// Replace all instances of 'from' to 'to' in `str` and return the new string.
static std::string replaceInStr(std::string str, StringRef from, StringRef to) {
size_t pos = 0;
while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos)
str.replace(pos, from.size(), to.data(), to.size());
return str;
}
/// Emit the builders for the given type.
static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
ArrayRef<TypeParameter> typeDefParams) {
bool genCheckedMethods = typeDef.genVerifyInvariantsDecl();
bool genCheckedMethods = typeDef.genVerifyDecl();
StringRef typeClass = typeDef.getCppClassName();
if (!typeDef.skipDefaultBuilders()) {
os << llvm::formatv(
@ -531,8 +544,10 @@ static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
typeDefParams));
if (genCheckedMethods) {
os << llvm::formatv(
"{0} {0}::getChecked(::mlir::Location loc{1}) {{\n"
" return Base::getChecked(loc{2});\n}\n",
"{0} {0}::getChecked("
"llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, "
"::mlir::MLIRContext *context{1}) {{\n"
" return Base::getChecked(emitError, context{2});\n}\n",
typeClass,
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
@ -542,16 +557,15 @@ static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
}
}
auto builderFmtCtx =
FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get");
auto inferredCtxBuilderFmtCtx = FmtContext().addSubst("_get", "Base::get");
auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context");
// Generate the builders specified by the user.
auto builderFmtCtx = FmtContext().addSubst("_ctxt", "context");
auto checkedBuilderFmtCtx = FmtContext()
.addSubst("_loc", "loc")
.addSubst("_ctxt", "loc.getContext()");
for (const TypeBuilder &builder : typeDef.getBuilders()) {
Optional<StringRef> body = builder.getBody();
Optional<StringRef> checkedBody =
genCheckedMethods ? builder.getCheckedBody() : llvm::None;
if (!body && !checkedBody)
if (!body)
continue;
std::string paramStr;
llvm::raw_string_ostream paramOS(paramStr);
@ -565,27 +579,33 @@ static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os,
paramOS.flush();
// Emit the `get` variant of the builder.
if (body) {
os << llvm::formatv("{0} {0}::get(", typeClass);
if (!builder.hasInferredContextParameter()) {
os << "::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
tgfmt(*body, &builderFmtCtx).str());
} else {
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, *body);
}
os << llvm::formatv("{0} {0}::get(", typeClass);
if (!builder.hasInferredContextParameter()) {
os << "::mlir::MLIRContext *context";
if (!paramStr.empty())
os << ", ";
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
tgfmt(*body, &builderFmtCtx).str());
} else {
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr,
tgfmt(*body, &inferredCtxBuilderFmtCtx).str());
}
// Emit the `getChecked` variant of the builder.
if (checkedBody) {
os << llvm::formatv("{0} {0}::getChecked(::mlir::Location loc",
if (genCheckedMethods) {
os << llvm::formatv("{0} "
"{0}::getChecked(llvm::function_ref<::mlir::"
"InFlightDiagnostic()> emitErrorFn",
typeClass);
std::string checkedBody =
replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, ");
if (!builder.hasInferredContextParameter()) {
os << ", ::mlir::MLIRContext *context";
checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str();
}
if (!paramStr.empty())
os << ", " << paramStr;
os << llvm::formatv(") {{\n {0};\n}\n",
tgfmt(*checkedBody, &checkedBuilderFmtCtx));
os << ", ";
os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, checkedBody);
}
}
}