[mlir][IR][NFC] Move context/location parameters of builtin Type::get methods to the start of the parameter list
This better matches the rest of the infrastructure, is much simpler, and makes it easier to move these types to being declaratively specified. Differential Revision: https://reviews.llvm.org/D93432
This commit is contained in:
parent
511cfe9441
commit
1b97cdf885
|
@ -2176,7 +2176,7 @@ def fir_DispatchOp : fir_Op<"dispatch",
|
|||
p.printOptionalAttrDict(getAttrs(), {"fn_type", "method"});
|
||||
auto resTy{getResultTypes()};
|
||||
llvm::SmallVector<mlir::Type, 8> argTy(getOperandTypes());
|
||||
p << " : " << mlir::FunctionType::get(argTy, resTy, getContext());
|
||||
p << " : " << mlir::FunctionType::get(getContext(), argTy, resTy);
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
|
|
@ -49,7 +49,7 @@ mlir::Type genFIRType(mlir::MLIRContext *context) {
|
|||
if constexpr (TC == Fortran::common::TypeCategory::Integer) {
|
||||
auto bits{Fortran::evaluate::Type<Fortran::common::TypeCategory::Integer,
|
||||
KIND>::Scalar::bits};
|
||||
return mlir::IntegerType::get(bits, context);
|
||||
return mlir::IntegerType::get(context, bits);
|
||||
} else if constexpr (TC == Fortran::common::TypeCategory::Logical ||
|
||||
TC == Fortran::common::TypeCategory::Character ||
|
||||
TC == Fortran::common::TypeCategory::Complex) {
|
||||
|
@ -278,7 +278,7 @@ private:
|
|||
|
||||
// some sequence of `n` bytes
|
||||
mlir::Type gen(const Fortran::evaluate::StaticDataObject::Pointer &ptr) {
|
||||
mlir::Type byteTy{mlir::IntegerType::get(8, context)};
|
||||
mlir::Type byteTy{mlir::IntegerType::get(context, 8)};
|
||||
return fir::SequenceType::get(trivialShape(ptr->itemBytes()), byteTy);
|
||||
}
|
||||
|
||||
|
|
|
@ -298,26 +298,26 @@ static constexpr RuntimeFunction pgmathPrecise[] = {
|
|||
|
||||
static mlir::FunctionType genF32F32FuncType(mlir::MLIRContext *context) {
|
||||
auto t = mlir::FloatType::getF32(context);
|
||||
return mlir::FunctionType::get({t}, {t}, context);
|
||||
return mlir::FunctionType::get(context, {t}, {t});
|
||||
}
|
||||
|
||||
static mlir::FunctionType genF64F64FuncType(mlir::MLIRContext *context) {
|
||||
auto t = mlir::FloatType::getF64(context);
|
||||
return mlir::FunctionType::get({t}, {t}, context);
|
||||
return mlir::FunctionType::get(context, {t}, {t});
|
||||
}
|
||||
|
||||
template <int Bits>
|
||||
static mlir::FunctionType genIntF64FuncType(mlir::MLIRContext *context) {
|
||||
auto t = mlir::FloatType::getF64(context);
|
||||
auto r = mlir::IntegerType::get(Bits, context);
|
||||
return mlir::FunctionType::get({t}, {r}, context);
|
||||
auto r = mlir::IntegerType::get(context, Bits);
|
||||
return mlir::FunctionType::get(context, {t}, {r});
|
||||
}
|
||||
|
||||
template <int Bits>
|
||||
static mlir::FunctionType genIntF32FuncType(mlir::MLIRContext *context) {
|
||||
auto t = mlir::FloatType::getF32(context);
|
||||
auto r = mlir::IntegerType::get(Bits, context);
|
||||
return mlir::FunctionType::get({t}, {r}, context);
|
||||
auto r = mlir::IntegerType::get(context, Bits);
|
||||
return mlir::FunctionType::get(context, {t}, {r});
|
||||
}
|
||||
|
||||
// TODO : Fill-up this table with more intrinsic.
|
||||
|
@ -585,8 +585,8 @@ getFunctionType(mlir::Type resultType, llvm::ArrayRef<mlir::Value> arguments,
|
|||
llvm::SmallVector<mlir::Type, 2> argumentTypes;
|
||||
for (auto &arg : arguments)
|
||||
argumentTypes.push_back(arg.getType());
|
||||
return mlir::FunctionType::get(argumentTypes, resultType,
|
||||
builder.getModule().getContext());
|
||||
return mlir::FunctionType::get(builder.getModule().getContext(),
|
||||
argumentTypes, resultType);
|
||||
}
|
||||
|
||||
/// fir::ExtendedValue to mlir::Value translation layer
|
||||
|
@ -1144,7 +1144,7 @@ mlir::Value IntrinsicLibrary::genMerge(mlir::Type,
|
|||
llvm::ArrayRef<mlir::Value> args) {
|
||||
assert(args.size() == 3);
|
||||
|
||||
auto i1Type = mlir::IntegerType::get(1, builder.getContext());
|
||||
auto i1Type = mlir::IntegerType::get(builder.getContext(), 1);
|
||||
auto mask = builder.createConvert(loc, i1Type, args[2]);
|
||||
return builder.create<mlir::SelectOp>(loc, mask, args[0], args[1]);
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ static constexpr TypeBuilderFunc getModel();
|
|||
template <>
|
||||
constexpr TypeBuilderFunc getModel<int>() {
|
||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(8 * sizeof(int), context);
|
||||
return mlir::IntegerType::get(context, 8 * sizeof(int));
|
||||
};
|
||||
}
|
||||
template <>
|
||||
|
@ -61,14 +61,14 @@ constexpr TypeBuilderFunc getModel<int &>() {
|
|||
template <>
|
||||
constexpr TypeBuilderFunc getModel<Fortran::runtime::io::Iostat>() {
|
||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(8 * sizeof(Fortran::runtime::io::Iostat),
|
||||
context);
|
||||
return mlir::IntegerType::get(context,
|
||||
8 * sizeof(Fortran::runtime::io::Iostat));
|
||||
};
|
||||
}
|
||||
template <>
|
||||
constexpr TypeBuilderFunc getModel<char *>() {
|
||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return fir::ReferenceType::get(mlir::IntegerType::get(8, context));
|
||||
return fir::ReferenceType::get(mlir::IntegerType::get(context, 8));
|
||||
};
|
||||
}
|
||||
template <>
|
||||
|
@ -78,26 +78,26 @@ constexpr TypeBuilderFunc getModel<const char *>() {
|
|||
template <>
|
||||
constexpr TypeBuilderFunc getModel<const char16_t *>() {
|
||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return fir::ReferenceType::get(mlir::IntegerType::get(16, context));
|
||||
return fir::ReferenceType::get(mlir::IntegerType::get(context, 16));
|
||||
};
|
||||
}
|
||||
template <>
|
||||
constexpr TypeBuilderFunc getModel<const char32_t *>() {
|
||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return fir::ReferenceType::get(mlir::IntegerType::get(32, context));
|
||||
return fir::ReferenceType::get(mlir::IntegerType::get(context, 32));
|
||||
};
|
||||
}
|
||||
template <>
|
||||
constexpr TypeBuilderFunc getModel<void **>() {
|
||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return fir::ReferenceType::get(
|
||||
fir::PointerType::get(mlir::IntegerType::get(8, context)));
|
||||
fir::PointerType::get(mlir::IntegerType::get(context, 8)));
|
||||
};
|
||||
}
|
||||
template <>
|
||||
constexpr TypeBuilderFunc getModel<std::int64_t>() {
|
||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(64, context);
|
||||
return mlir::IntegerType::get(context, 64);
|
||||
};
|
||||
}
|
||||
template <>
|
||||
|
@ -110,7 +110,7 @@ constexpr TypeBuilderFunc getModel<std::int64_t &>() {
|
|||
template <>
|
||||
constexpr TypeBuilderFunc getModel<std::size_t>() {
|
||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(8 * sizeof(std::size_t), context);
|
||||
return mlir::IntegerType::get(context, 8 * sizeof(std::size_t));
|
||||
};
|
||||
}
|
||||
template <>
|
||||
|
@ -146,7 +146,7 @@ constexpr TypeBuilderFunc getModel<float &>() {
|
|||
template <>
|
||||
constexpr TypeBuilderFunc getModel<bool>() {
|
||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||
return mlir::IntegerType::get(1, context);
|
||||
return mlir::IntegerType::get(context, 1);
|
||||
};
|
||||
}
|
||||
template <>
|
||||
|
@ -190,7 +190,7 @@ struct RuntimeTableKey<RT(ATs...)> {
|
|||
llvm::SmallVector<mlir::Type, sizeof...(ATs)> argTys;
|
||||
for (auto f : args)
|
||||
argTys.push_back(f(ctxt));
|
||||
return mlir::FunctionType::get(argTys, {retTy}, ctxt);
|
||||
return mlir::FunctionType::get(ctxt, argTys, {retTy});
|
||||
};
|
||||
}
|
||||
};
|
||||
|
|
|
@ -151,7 +151,7 @@ mlir::Type fir::BoxDimsOp::getTupleType() {
|
|||
// note: triple, but 4 is nearest power of 2
|
||||
llvm::SmallVector<mlir::Type, 4> triple{
|
||||
getResult(0).getType(), getResult(1).getType(), getResult(2).getType()};
|
||||
return mlir::TupleType::get(triple, getContext());
|
||||
return mlir::TupleType::get(getContext(), triple);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -171,7 +171,7 @@ static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) {
|
|||
auto resultTypes{op.getResultTypes()};
|
||||
llvm::SmallVector<Type, 8> argTypes(
|
||||
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
|
||||
p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
|
||||
p << " : " << FunctionType::get(op.getContext(), argTypes, resultTypes);
|
||||
}
|
||||
|
||||
static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser,
|
||||
|
@ -1565,4 +1565,3 @@ fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
|
|||
|
||||
#define GET_OP_CLASSES
|
||||
#include "flang/Optimizer/Dialect/FIROps.cpp.inc"
|
||||
|
||||
|
|
|
@ -35,8 +35,8 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect,
|
|||
AllTypesMatch<["src", "a", "dst"]>,
|
||||
TypesMatchWith<"imm has the same number of bits as elements in dst",
|
||||
"dst", "imm",
|
||||
"IntegerType::get(($_self.cast<VectorType>().getShape()[0]),"
|
||||
" $_self.getContext())">]> {
|
||||
"IntegerType::get($_self.getContext(), "
|
||||
"($_self.cast<VectorType>().getShape()[0]))">]> {
|
||||
let summary = "Masked roundscale op";
|
||||
let description = [{
|
||||
The mask.rndscale op is an AVX512 specific op that can lower to the proper
|
||||
|
@ -67,8 +67,8 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
|
|||
AllTypesMatch<["src", "a", "b", "dst"]>,
|
||||
TypesMatchWith<"k has the same number of bits as elements in dst",
|
||||
"dst", "k",
|
||||
"IntegerType::get(($_self.cast<VectorType>().getShape()[0]),"
|
||||
" $_self.getContext())">]> {
|
||||
"IntegerType::get($_self.getContext(), "
|
||||
"($_self.cast<VectorType>().getShape()[0]))">]> {
|
||||
let summary = "ScaleF op";
|
||||
let description = [{
|
||||
The `mask.scalef` op is an AVX512 specific op that can lower to the proper
|
||||
|
|
|
@ -911,7 +911,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
|||
auto attr = (*this)->getAttr("operand_segment_sizes")
|
||||
.cast<DenseIntElementsAttr>();
|
||||
unsigned i = 0;
|
||||
auto newAttr = attr.mapValues(IntegerType::get(32, getContext()),
|
||||
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
|
||||
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
|
||||
getOperation()->setAttr("operand_segment_sizes", newAttr);
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ public:
|
|||
/// Get or create a ComplexType with the provided element type. This emits
|
||||
/// and error at the specified location and returns null if the element type
|
||||
/// isn't supported.
|
||||
static ComplexType getChecked(Type elementType, Location location);
|
||||
static ComplexType getChecked(Location location, Type elementType);
|
||||
|
||||
/// Verify the construction of an integer type.
|
||||
static LogicalResult verifyConstructionInvariants(Location loc,
|
||||
|
@ -93,27 +93,27 @@ public:
|
|||
/// The created IntegerType is signless (i.e., no signedness semantics).
|
||||
/// Assume the width is within the allowed range and assert on failures. Use
|
||||
/// getChecked to handle failures gracefully.
|
||||
static IntegerType get(unsigned width, MLIRContext *context);
|
||||
static IntegerType get(MLIRContext *context, unsigned width);
|
||||
|
||||
/// Get or create a new IntegerType of the given width within the context.
|
||||
/// The created IntegerType has signedness semantics as indicated via
|
||||
/// `signedness`. Assume the width is within the allowed range and assert on
|
||||
/// failures. Use getChecked to handle failures gracefully.
|
||||
static IntegerType get(unsigned width, SignednessSemantics signedness,
|
||||
MLIRContext *context);
|
||||
static IntegerType get(MLIRContext *context, unsigned width,
|
||||
SignednessSemantics signedness);
|
||||
|
||||
/// Get or create a new IntegerType of the given width within the context,
|
||||
/// defined at the given, potentially unknown, location. The created
|
||||
/// IntegerType is signless (i.e., no signedness semantics). If the width is
|
||||
/// outside the allowed range, emit errors and return a null type.
|
||||
static IntegerType getChecked(unsigned width, Location location);
|
||||
static IntegerType getChecked(Location location, unsigned width);
|
||||
|
||||
/// Get or create a new IntegerType of the given width within the context,
|
||||
/// defined at the given, potentially unknown, location. The created
|
||||
/// IntegerType has signedness semantics as indicated via `signedness`. If the
|
||||
/// width is outside the allowed range, emit errors and return a null type.
|
||||
static IntegerType getChecked(unsigned width, SignednessSemantics signedness,
|
||||
Location location);
|
||||
static IntegerType getChecked(Location location, unsigned width,
|
||||
SignednessSemantics signedness);
|
||||
|
||||
/// Verify the construction of an integer type.
|
||||
static LogicalResult
|
||||
|
@ -180,8 +180,8 @@ class FunctionType
|
|||
public:
|
||||
using Base::Base;
|
||||
|
||||
static FunctionType get(TypeRange inputs, TypeRange results,
|
||||
MLIRContext *context);
|
||||
static FunctionType get(MLIRContext *context, TypeRange inputs,
|
||||
TypeRange results);
|
||||
|
||||
/// Input types.
|
||||
unsigned getNumInputs() const;
|
||||
|
@ -211,14 +211,14 @@ public:
|
|||
using Base::Base;
|
||||
|
||||
/// Get or create a new OpaqueType with the provided dialect and string data.
|
||||
static OpaqueType get(Identifier dialect, StringRef typeData,
|
||||
MLIRContext *context);
|
||||
static OpaqueType get(MLIRContext *context, Identifier dialect,
|
||||
StringRef typeData);
|
||||
|
||||
/// Get or create a new OpaqueType with the provided dialect and string data.
|
||||
/// If the given identifier is not a valid namespace for a dialect, then a
|
||||
/// null type is returned.
|
||||
static OpaqueType getChecked(Identifier dialect, StringRef typeData,
|
||||
MLIRContext *context, Location location);
|
||||
static OpaqueType getChecked(Location location, Identifier dialect,
|
||||
StringRef typeData);
|
||||
|
||||
/// Returns the dialect namespace of the opaque type.
|
||||
Identifier getDialectNamespace() const;
|
||||
|
@ -335,8 +335,8 @@ public:
|
|||
/// declared at the given, potentially unknown, location. If the VectorType
|
||||
/// defined by the arguments would be ill-formed, emit errors and return
|
||||
/// nullptr-wrapping type.
|
||||
static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
|
||||
Location location);
|
||||
static VectorType getChecked(Location location, ArrayRef<int64_t> shape,
|
||||
Type elementType);
|
||||
|
||||
/// Verify the construction of a vector type.
|
||||
static LogicalResult verifyConstructionInvariants(Location loc,
|
||||
|
@ -394,8 +394,8 @@ public:
|
|||
/// type declared at the given, potentially unknown, location. If the
|
||||
/// RankedTensorType defined by the arguments would be ill-formed, emit errors
|
||||
/// and return a nullptr-wrapping type.
|
||||
static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType,
|
||||
Location location);
|
||||
static RankedTensorType getChecked(Location location, ArrayRef<int64_t> shape,
|
||||
Type elementType);
|
||||
|
||||
/// Verify the construction of a ranked tensor type.
|
||||
static LogicalResult verifyConstructionInvariants(Location loc,
|
||||
|
@ -424,7 +424,7 @@ public:
|
|||
/// type declared at the given, potentially unknown, location. If the
|
||||
/// UnrankedTensorType defined by the arguments would be ill-formed, emit
|
||||
/// errors and return a nullptr-wrapping type.
|
||||
static UnrankedTensorType getChecked(Type elementType, Location location);
|
||||
static UnrankedTensorType getChecked(Location location, Type elementType);
|
||||
|
||||
/// Verify the construction of a unranked tensor type.
|
||||
static LogicalResult verifyConstructionInvariants(Location loc,
|
||||
|
@ -527,9 +527,10 @@ public:
|
|||
/// UnknownLoc. If the MemRefType defined by the arguments would be
|
||||
/// ill-formed, emits errors (to the handler registered with the context or to
|
||||
/// the error stream) and returns nullptr.
|
||||
static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
|
||||
static MemRefType getChecked(Location location, ArrayRef<int64_t> shape,
|
||||
Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace, Location location);
|
||||
unsigned memorySpace);
|
||||
|
||||
ArrayRef<int64_t> getShape() const;
|
||||
|
||||
|
@ -573,8 +574,8 @@ public:
|
|||
/// type and memory space declared at the given, potentially unknown,
|
||||
/// location. If the UnrankedMemRefType defined by the arguments would be
|
||||
/// ill-formed, emit errors and return a nullptr-wrapping type.
|
||||
static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
|
||||
Location location);
|
||||
static UnrankedMemRefType getChecked(Location location, Type elementType,
|
||||
unsigned memorySpace);
|
||||
|
||||
/// Verify the construction of a unranked memref type.
|
||||
static LogicalResult verifyConstructionInvariants(Location loc,
|
||||
|
@ -600,7 +601,7 @@ public:
|
|||
|
||||
/// Get or create a new TupleType with the provided element types. Assumes the
|
||||
/// arguments define a well-formed type.
|
||||
static TupleType get(TypeRange elementTypes, MLIRContext *context);
|
||||
static TupleType get(MLIRContext *context, TypeRange elementTypes);
|
||||
|
||||
/// Get or create an empty tuple type.
|
||||
static TupleType get(MLIRContext *context);
|
||||
|
|
|
@ -475,8 +475,9 @@ def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
|
|||
class OpaqueType<string dialect, string name, string description>
|
||||
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
|
||||
description>,
|
||||
BuildableType<"::mlir::OpaqueType::get($_builder.getIdentifier(\""
|
||||
# dialect # "\"), \"" # name # "\", $_builder.getContext())">;
|
||||
BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
|
||||
"$_builder.getIdentifier(\"" # dialect # "\"), \""
|
||||
# name # "\")">;
|
||||
|
||||
// Function Type
|
||||
|
||||
|
|
|
@ -26,15 +26,15 @@ bool mlirTypeIsAInteger(MlirType type) {
|
|||
}
|
||||
|
||||
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
|
||||
return wrap(IntegerType::get(bitwidth, unwrap(ctx)));
|
||||
return wrap(IntegerType::get(unwrap(ctx), bitwidth));
|
||||
}
|
||||
|
||||
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
|
||||
return wrap(IntegerType::get(bitwidth, IntegerType::Signed, unwrap(ctx)));
|
||||
return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed));
|
||||
}
|
||||
|
||||
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
|
||||
return wrap(IntegerType::get(bitwidth, IntegerType::Unsigned, unwrap(ctx)));
|
||||
return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned));
|
||||
}
|
||||
|
||||
unsigned mlirIntegerTypeGetWidth(MlirType type) {
|
||||
|
@ -172,8 +172,8 @@ MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
|
|||
MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape,
|
||||
MlirType elementType, MlirLocation loc) {
|
||||
return wrap(VectorType::getChecked(
|
||||
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
|
||||
unwrap(loc)));
|
||||
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
|
||||
unwrap(elementType)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -201,8 +201,8 @@ MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, const int64_t *shape,
|
|||
MlirType elementType,
|
||||
MlirLocation loc) {
|
||||
return wrap(RankedTensorType::getChecked(
|
||||
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
|
||||
unwrap(loc)));
|
||||
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
|
||||
unwrap(elementType)));
|
||||
}
|
||||
|
||||
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
|
||||
|
@ -211,7 +211,7 @@ MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
|
|||
|
||||
MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
|
||||
MlirLocation loc) {
|
||||
return wrap(UnrankedTensorType::getChecked(unwrap(elementType), unwrap(loc)));
|
||||
return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -244,8 +244,8 @@ MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
|
|||
unsigned memorySpace,
|
||||
MlirLocation loc) {
|
||||
return wrap(MemRefType::getChecked(
|
||||
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
|
||||
llvm::None, memorySpace, unwrap(loc)));
|
||||
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
|
||||
unwrap(elementType), llvm::None, memorySpace));
|
||||
}
|
||||
|
||||
intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
|
||||
|
@ -272,8 +272,8 @@ MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
|
|||
MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
|
||||
unsigned memorySpace,
|
||||
MlirLocation loc) {
|
||||
return wrap(UnrankedMemRefType::getChecked(unwrap(elementType), memorySpace,
|
||||
unwrap(loc)));
|
||||
return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
|
||||
memorySpace));
|
||||
}
|
||||
|
||||
unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
|
||||
|
@ -290,7 +290,7 @@ MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
|
|||
MlirType const *elements) {
|
||||
SmallVector<Type, 4> types;
|
||||
ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
|
||||
return wrap(TupleType::get(typeRef, unwrap(ctx)));
|
||||
return wrap(TupleType::get(unwrap(ctx), typeRef));
|
||||
}
|
||||
|
||||
intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
|
||||
|
@ -316,7 +316,7 @@ MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
|
|||
SmallVector<Type, 4> resultsList;
|
||||
(void)unwrapList(numInputs, inputs, inputsList);
|
||||
(void)unwrapList(numResults, results, resultsList);
|
||||
return wrap(FunctionType::get(inputsList, resultsList, unwrap(ctx)));
|
||||
return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList));
|
||||
}
|
||||
|
||||
intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
|
||||
|
|
|
@ -53,52 +53,52 @@ namespace {
|
|||
struct AsyncAPI {
|
||||
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
|
||||
auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||
auto count = IntegerType::get(32, ctx);
|
||||
return FunctionType::get({ref, count}, {}, ctx);
|
||||
auto count = IntegerType::get(ctx, 32);
|
||||
return FunctionType::get(ctx, {ref, count}, {});
|
||||
}
|
||||
|
||||
static FunctionType createTokenFunctionType(MLIRContext *ctx) {
|
||||
return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
|
||||
return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
|
||||
}
|
||||
|
||||
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
|
||||
return FunctionType::get({}, {GroupType::get(ctx)}, ctx);
|
||||
return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
|
||||
}
|
||||
|
||||
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
|
||||
return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
|
||||
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
|
||||
}
|
||||
|
||||
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
|
||||
return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
|
||||
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
|
||||
}
|
||||
|
||||
static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
|
||||
return FunctionType::get({GroupType::get(ctx)}, {}, ctx);
|
||||
return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
|
||||
}
|
||||
|
||||
static FunctionType executeFunctionType(MLIRContext *ctx) {
|
||||
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||
auto resume = resumeFunctionType(ctx).getPointerTo();
|
||||
return FunctionType::get({hdl, resume}, {}, ctx);
|
||||
return FunctionType::get(ctx, {hdl, resume}, {});
|
||||
}
|
||||
|
||||
static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
|
||||
auto i64 = IntegerType::get(64, ctx);
|
||||
return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64},
|
||||
ctx);
|
||||
auto i64 = IntegerType::get(ctx, 64);
|
||||
return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
|
||||
{i64});
|
||||
}
|
||||
|
||||
static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
|
||||
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||
auto resume = resumeFunctionType(ctx).getPointerTo();
|
||||
return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
|
||||
return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
|
||||
}
|
||||
|
||||
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
|
||||
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||
auto resume = resumeFunctionType(ctx).getPointerTo();
|
||||
return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx);
|
||||
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
|
||||
}
|
||||
|
||||
// Auxiliary coroutine resume intrinsic wrapper.
|
||||
|
@ -690,7 +690,7 @@ public:
|
|||
if (!addToGroup.operand().getType().isa<TokenType>())
|
||||
return failure();
|
||||
|
||||
auto i64 = IntegerType::get(64, op->getContext());
|
||||
auto i64 = IntegerType::get(op->getContext(), 64);
|
||||
rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -122,7 +122,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
|
|||
}
|
||||
|
||||
// Declare vulkan launch function.
|
||||
auto funcType = FunctionType::get(vulkanLaunchTypes, {}, loc->getContext());
|
||||
auto funcType = builder.getFunctionType(vulkanLaunchTypes, {});
|
||||
builder.create<FuncOp>(loc, kVulkanLaunch, funcType).setPrivate();
|
||||
|
||||
return success();
|
||||
|
|
|
@ -84,7 +84,7 @@ static LLVMType getPtrToElementType(T containerType,
|
|||
/// };
|
||||
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
|
||||
auto *context = t.getContext();
|
||||
auto int64Ty = converter.convertType(IntegerType::get(64, context))
|
||||
auto int64Ty = converter.convertType(IntegerType::get(context, 64))
|
||||
.cast<LLVM::LLVMType>();
|
||||
return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
|
||||
}
|
||||
|
|
|
@ -65,7 +65,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
|
|||
assert(op->getNumResults() == 0 &&
|
||||
"Library call for linalg operation can be generated only for ops that "
|
||||
"have void return types");
|
||||
auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
|
||||
auto libFnType = rewriter.getFunctionType(inputTypes, {});
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
// Insert before module terminator.
|
||||
|
|
|
@ -407,8 +407,7 @@ public:
|
|||
// cover all possible corner cases.
|
||||
if (isSignedIntegerOrVector(srcType) ||
|
||||
isUnsignedIntegerOrVector(srcType)) {
|
||||
auto *context = rewriter.getContext();
|
||||
auto signlessType = IntegerType::get(getBitWidth(srcType), context);
|
||||
auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
|
||||
|
||||
if (srcType.isa<VectorType>()) {
|
||||
auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
|
||||
|
|
|
@ -584,7 +584,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
|||
std::swap(ivsStorage.back(), ivsStorage[coalescedIdx]);
|
||||
|
||||
ArrayRef<Value> ivs(ivsStorage);
|
||||
Value pos = std_index_cast(IntegerType::get(32, ctx), ivs.back());
|
||||
Value pos = std_index_cast(IntegerType::get(ctx, 32), ivs.back());
|
||||
Value inVector = local(ivs.drop_back());
|
||||
auto loadValue = [&](ArrayRef<Value> indices) {
|
||||
Value vector = vector_insert_element(remote(indices), inVector, pos);
|
||||
|
@ -671,7 +671,7 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
|
|||
|
||||
ArrayRef<Value> ivs(ivsStorage);
|
||||
Value pos =
|
||||
std_index_cast(IntegerType::get(32, op->getContext()), ivs.back());
|
||||
std_index_cast(IntegerType::get(op->getContext(), 32), ivs.back());
|
||||
auto storeValue = [&](ArrayRef<Value> indices) {
|
||||
Value scalar = vector_extract_element(local(ivs.drop_back()), pos);
|
||||
remote(indices) = scalar;
|
||||
|
|
|
@ -152,7 +152,7 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
|
|||
int32_t numDependencies = dependencies.size();
|
||||
int32_t numOperands = operands.size();
|
||||
auto operandSegmentSizes = DenseIntElementsAttr::get(
|
||||
VectorType::get({2}, IntegerType::get(32, result.getContext())),
|
||||
VectorType::get({2}, builder.getIntegerType(32)),
|
||||
{numDependencies, numOperands});
|
||||
result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
|
||||
|
||||
|
|
|
@ -118,7 +118,7 @@ LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
|
|||
builder.setInsertionPointToStart(value.getParentBlock());
|
||||
|
||||
Location loc = value.getLoc();
|
||||
auto i32 = IntegerType::get(32, ctx);
|
||||
auto i32 = IntegerType::get(ctx, 32);
|
||||
|
||||
// Drop the reference count immediately if the value has no uses.
|
||||
if (value.getUses().empty()) {
|
||||
|
|
|
@ -31,7 +31,7 @@ struct GpuAllReduceRewriter {
|
|||
: funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
|
||||
loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
|
||||
indexType(IndexType::get(reduceOp.getContext())),
|
||||
int32Type(IntegerType::get(/*width=*/32, reduceOp.getContext())) {}
|
||||
int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}
|
||||
|
||||
/// Creates an all_reduce across the workgroup.
|
||||
///
|
||||
|
|
|
@ -155,7 +155,7 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
|
|||
kernelOperandTypes.push_back(operand.getType());
|
||||
}
|
||||
FunctionType type =
|
||||
FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
|
||||
FunctionType::get(launchOp.getContext(), kernelOperandTypes, {});
|
||||
auto outlinedFunc = builder.create<gpu::GPUFuncOp>(loc, kernelFnName, type);
|
||||
outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
|
||||
builder.getUnitAttr());
|
||||
|
|
|
@ -120,8 +120,8 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
|
|||
static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
|
||||
auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
|
||||
|
||||
auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()},
|
||||
op.getContext());
|
||||
auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()},
|
||||
{op.getType()});
|
||||
|
||||
p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
|
||||
if (op.alignment().hasValue() && *op.alignment() != 0)
|
||||
|
@ -781,7 +781,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
|
|||
|
||||
// Reconstruct the function MLIR function type from operand and result types.
|
||||
p << " : "
|
||||
<< FunctionType::get(args.getTypes(), op.getResultTypes(), op.getContext());
|
||||
<< FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes());
|
||||
}
|
||||
|
||||
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
|
||||
|
|
|
@ -76,25 +76,25 @@ static Value allocBuffer(const LinalgPromotionOptions &options,
|
|||
IntegerAttr alignment_attr;
|
||||
if (alignment.hasValue())
|
||||
alignment_attr =
|
||||
IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue());
|
||||
IntegerAttr::get(IntegerType::get(ctx, 64), alignment.getValue());
|
||||
if (!dynamicBuffers)
|
||||
if (auto cst = size.getDefiningOp<ConstantIndexOp>())
|
||||
return options.useAlloca
|
||||
? std_alloca(MemRefType::get(width * cst.getValue(),
|
||||
IntegerType::get(8, ctx)),
|
||||
IntegerType::get(ctx, 8)),
|
||||
ValueRange{}, alignment_attr)
|
||||
.value
|
||||
: std_alloc(MemRefType::get(width * cst.getValue(),
|
||||
IntegerType::get(8, ctx)),
|
||||
IntegerType::get(ctx, 8)),
|
||||
ValueRange{}, alignment_attr)
|
||||
.value;
|
||||
Value mul =
|
||||
folded_std_muli(folder, folded_std_constant_index(folder, width), size);
|
||||
return options.useAlloca
|
||||
? std_alloca(MemRefType::get(-1, IntegerType::get(8, ctx)), mul,
|
||||
? std_alloca(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
|
||||
alignment_attr)
|
||||
.value
|
||||
: std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul,
|
||||
: std_alloc(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
|
||||
alignment_attr)
|
||||
.value;
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
|
|||
int64_t &qmax) {
|
||||
// Hard-coded type mapping from TFLite.
|
||||
if (numBits <= 8) {
|
||||
storageType = IntegerType::get(8, ctx);
|
||||
storageType = IntegerType::get(ctx, 8);
|
||||
if (isSigned) {
|
||||
qmin = -128;
|
||||
qmax = 127;
|
||||
|
@ -27,7 +27,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
|
|||
qmax = 255;
|
||||
}
|
||||
} else if (numBits <= 16) {
|
||||
storageType = IntegerType::get(16, ctx);
|
||||
storageType = IntegerType::get(ctx, 16);
|
||||
if (isSigned) {
|
||||
qmin = -32768;
|
||||
qmax = 32767;
|
||||
|
@ -36,7 +36,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
|
|||
qmax = 65535;
|
||||
}
|
||||
} else if (numBits <= 32) {
|
||||
storageType = IntegerType::get(32, ctx);
|
||||
storageType = IntegerType::get(ctx, 32);
|
||||
if (isSigned) {
|
||||
qmin = std::numeric_limits<int32_t>::min();
|
||||
qmax = std::numeric_limits<int32_t>::max();
|
||||
|
|
|
@ -79,7 +79,7 @@ UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) {
|
|||
int64_t chunkSize =
|
||||
std::accumulate(std::next(shape.begin(), quantizationDim + 1),
|
||||
shape.end(), 1, std::multiplies<int64_t>());
|
||||
Type newElementType = IntegerType::get(storageBitWidth, attr.getContext());
|
||||
Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth);
|
||||
return attr.mapValues(newElementType, [&](const APFloat &old) {
|
||||
int chunkIndex = (flattenIndex++) / chunkSize;
|
||||
return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
|
||||
|
|
|
@ -96,7 +96,7 @@ void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
|
|||
|
||||
ValueRange values(captures.getArrayRef());
|
||||
FunctionType type =
|
||||
FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx);
|
||||
FunctionType::get(ctx, values.getTypes(), ifOp.getResultTypes());
|
||||
auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
|
||||
b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
|
||||
BlockAndValueMapping bvm;
|
||||
|
|
|
@ -123,7 +123,7 @@ spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize, MLIRContext *context) {
|
|||
assert(localSize.size() == 3);
|
||||
return spirv::EntryPointABIAttr::get(
|
||||
DenseElementsAttr::get<int32_t>(
|
||||
VectorType::get(3, IntegerType::get(32, context)), localSize)
|
||||
VectorType::get(3, IntegerType::get(context, 32)), localSize)
|
||||
.cast<DenseIntElementsAttr>(),
|
||||
context);
|
||||
}
|
||||
|
|
|
@ -93,7 +93,7 @@ Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
|
|||
// instructions. The Vulkan spec requires the builtins like
|
||||
// GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
|
||||
// SExtended to 64-bit for index computations.
|
||||
return IntegerType::get(32, context);
|
||||
return IntegerType::get(context, 32);
|
||||
}
|
||||
|
||||
/// Mapping between SPIR-V storage classes to memref memory spaces.
|
||||
|
@ -260,8 +260,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
|
|||
|
||||
auto intType = type.cast<IntegerType>();
|
||||
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
|
||||
return IntegerType::get(/*width=*/32, intType.getSignedness(),
|
||||
targetEnv.getContext());
|
||||
return IntegerType::get(targetEnv.getContext(), /*width=*/32,
|
||||
intType.getSignedness());
|
||||
}
|
||||
|
||||
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
|
||||
|
|
|
@ -714,7 +714,7 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
}
|
||||
|
||||
FunctionType CallOp::getCalleeType() {
|
||||
return FunctionType::get(getOperandTypes(), getResultTypes(), getContext());
|
||||
return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -753,7 +753,7 @@ void CallIndirectOp::getCanonicalizationPatterns(
|
|||
|
||||
// Return the type of the same shape (scalar, vector or tensor) containing i1.
|
||||
static Type getI1SameShape(Type type) {
|
||||
auto i1Type = IntegerType::get(1, type.getContext());
|
||||
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
return RankedTensorType::get(tensorType.getShape(), i1Type);
|
||||
if (type.isa<UnrankedTensorType>())
|
||||
|
@ -914,7 +914,7 @@ OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
|
||||
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
|
||||
return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
|
||||
return IntegerAttr::get(IntegerType::get(getContext(), 1), APInt(1, val));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1426,7 +1426,7 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
|
|||
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
|
||||
MLIRContext *context) {
|
||||
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
|
||||
return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
|
||||
return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
|
||||
});
|
||||
return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
|
||||
}
|
||||
|
@ -2767,7 +2767,7 @@ static ParseResult parseTupleOp(OpAsmParser &parser, OperationState &result) {
|
|||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonTypeList(types) ||
|
||||
parser.resolveOperands(operandInfos, types, loc, result.operands) ||
|
||||
parser.addTypeToList(TupleType::get(types, ctx), result.types));
|
||||
parser.addTypeToList(TupleType::get(ctx, types), result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, TupleOp op) {
|
||||
|
|
|
@ -215,7 +215,7 @@ static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
|
|||
// Create Vector type and add to 'vectorTypes[i]'.
|
||||
vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
|
||||
}
|
||||
return TupleType::get(vectorTypes, builder.getContext());
|
||||
return builder.getTupleType(vectorTypes);
|
||||
}
|
||||
|
||||
// UnrolledVectorState aggregates per-operand/result vector state required for
|
||||
|
|
|
@ -52,27 +52,27 @@ FloatType Builder::getF64Type() { return FloatType::getF64(context); }
|
|||
|
||||
IndexType Builder::getIndexType() { return IndexType::get(context); }
|
||||
|
||||
IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
|
||||
IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); }
|
||||
|
||||
IntegerType Builder::getI32Type() { return IntegerType::get(32, context); }
|
||||
IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); }
|
||||
|
||||
IntegerType Builder::getI64Type() { return IntegerType::get(64, context); }
|
||||
IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); }
|
||||
|
||||
IntegerType Builder::getIntegerType(unsigned width) {
|
||||
return IntegerType::get(width, context);
|
||||
return IntegerType::get(context, width);
|
||||
}
|
||||
|
||||
IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
|
||||
return IntegerType::get(
|
||||
width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
|
||||
context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned);
|
||||
}
|
||||
|
||||
FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
|
||||
return FunctionType::get(inputs, results, context);
|
||||
return FunctionType::get(context, inputs, results);
|
||||
}
|
||||
|
||||
TupleType Builder::getTupleType(TypeRange elementTypes) {
|
||||
return TupleType::get(elementTypes, context);
|
||||
return TupleType::get(context, elementTypes);
|
||||
}
|
||||
|
||||
NoneType Builder::getNoneType() { return NoneType::get(context); }
|
||||
|
|
|
@ -179,7 +179,7 @@ FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
|
|||
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
|
||||
if (!mapper.contains(getArgument(i)))
|
||||
inputTypes.push_back(newType.getInput(i));
|
||||
newType = FunctionType::get(inputTypes, newType.getResults(), getContext());
|
||||
newType = FunctionType::get(getContext(), inputTypes, newType.getResults());
|
||||
}
|
||||
|
||||
// Create the new function.
|
||||
|
|
|
@ -35,7 +35,7 @@ ComplexType ComplexType::get(Type elementType) {
|
|||
return Base::get(elementType.getContext(), elementType);
|
||||
}
|
||||
|
||||
ComplexType ComplexType::getChecked(Type elementType, Location location) {
|
||||
ComplexType ComplexType::getChecked(Location location, Type elementType) {
|
||||
return Base::getChecked(location, elementType);
|
||||
}
|
||||
|
||||
|
@ -76,7 +76,7 @@ IntegerType::SignednessSemantics IntegerType::getSignedness() const {
|
|||
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
|
||||
if (!scale)
|
||||
return IntegerType();
|
||||
return IntegerType::get(scale * getWidth(), getSignedness(), getContext());
|
||||
return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -126,8 +126,8 @@ FloatType FloatType::scaleElementBitwidth(unsigned scale) {
|
|||
// FunctionType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
|
||||
MLIRContext *context) {
|
||||
FunctionType FunctionType::get(MLIRContext *context, TypeRange inputs,
|
||||
TypeRange results) {
|
||||
return Base::get(context, inputs, results);
|
||||
}
|
||||
|
||||
|
@ -182,20 +182,20 @@ FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
|
|||
newResultTypes = newResultTypesBuffer;
|
||||
}
|
||||
|
||||
return get(newInputTypes, newResultTypes, getContext());
|
||||
return get(getContext(), newInputTypes, newResultTypes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpaqueType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
|
||||
MLIRContext *context) {
|
||||
OpaqueType OpaqueType::get(MLIRContext *context, Identifier dialect,
|
||||
StringRef typeData) {
|
||||
return Base::get(context, dialect, typeData);
|
||||
}
|
||||
|
||||
OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
|
||||
MLIRContext *context, Location location) {
|
||||
OpaqueType OpaqueType::getChecked(Location location, Identifier dialect,
|
||||
StringRef typeData) {
|
||||
return Base::getChecked(location, dialect, typeData);
|
||||
}
|
||||
|
||||
|
@ -313,8 +313,8 @@ VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
|
|||
return Base::get(elementType.getContext(), shape, elementType);
|
||||
}
|
||||
|
||||
VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
|
||||
Location location) {
|
||||
VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape,
|
||||
Type elementType) {
|
||||
return Base::getChecked(location, shape, elementType);
|
||||
}
|
||||
|
||||
|
@ -379,9 +379,9 @@ RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
|
|||
return Base::get(elementType.getContext(), shape, elementType);
|
||||
}
|
||||
|
||||
RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
|
||||
Type elementType,
|
||||
Location location) {
|
||||
RankedTensorType RankedTensorType::getChecked(Location location,
|
||||
ArrayRef<int64_t> shape,
|
||||
Type elementType) {
|
||||
return Base::getChecked(location, shape, elementType);
|
||||
}
|
||||
|
||||
|
@ -406,8 +406,8 @@ UnrankedTensorType UnrankedTensorType::get(Type elementType) {
|
|||
return Base::get(elementType.getContext(), elementType);
|
||||
}
|
||||
|
||||
UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
|
||||
Location location) {
|
||||
UnrankedTensorType UnrankedTensorType::getChecked(Location location,
|
||||
Type elementType) {
|
||||
return Base::getChecked(location, elementType);
|
||||
}
|
||||
|
||||
|
@ -448,9 +448,10 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
|
|||
/// UnknownLoc. If the MemRefType defined by the arguments would be
|
||||
/// ill-formed, emits errors (to the handler registered with the context or to
|
||||
/// the error stream) and returns nullptr.
|
||||
MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
|
||||
MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape,
|
||||
Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace, Location location) {
|
||||
unsigned memorySpace) {
|
||||
return getImpl(shape, elementType, affineMapComposition, memorySpace,
|
||||
location);
|
||||
}
|
||||
|
@ -524,9 +525,9 @@ UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
|
|||
return Base::get(elementType.getContext(), elementType, memorySpace);
|
||||
}
|
||||
|
||||
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
|
||||
unsigned memorySpace,
|
||||
Location location) {
|
||||
UnrankedMemRefType UnrankedMemRefType::getChecked(Location location,
|
||||
Type elementType,
|
||||
unsigned memorySpace) {
|
||||
return Base::getChecked(location, elementType, memorySpace);
|
||||
}
|
||||
|
||||
|
@ -694,12 +695,12 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
|
|||
|
||||
/// Get or create a new TupleType with the provided element types. Assumes the
|
||||
/// arguments define a well-formed type.
|
||||
TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
|
||||
TupleType TupleType::get(MLIRContext *context, TypeRange elementTypes) {
|
||||
return Base::get(context, elementTypes);
|
||||
}
|
||||
|
||||
/// Get or create an empty tuple type.
|
||||
TupleType TupleType::get(MLIRContext *context) { return get({}, context); }
|
||||
TupleType TupleType::get(MLIRContext *context) { return get(context, {}); }
|
||||
|
||||
/// Return the elements types for this tuple.
|
||||
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
|
||||
|
|
|
@ -82,7 +82,7 @@ Type Dialect::parseType(DialectAsmParser &parser) const {
|
|||
// If this dialect allows unknown types, then represent this with OpaqueType.
|
||||
if (allowsUnknownTypes()) {
|
||||
auto ns = Identifier::get(getNamespace(), getContext());
|
||||
return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
|
||||
return OpaqueType::get(getContext(), ns, parser.getFullSymbolSpec());
|
||||
}
|
||||
|
||||
parser.emitError(parser.getNameLoc())
|
||||
|
|
|
@ -772,25 +772,23 @@ getCachedIntegerType(unsigned width,
|
|||
}
|
||||
}
|
||||
|
||||
IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
|
||||
return get(width, IntegerType::Signless, context);
|
||||
IntegerType IntegerType::get(MLIRContext *context, unsigned width) {
|
||||
return get(context, width, IntegerType::Signless);
|
||||
}
|
||||
|
||||
IntegerType IntegerType::get(unsigned width,
|
||||
IntegerType::SignednessSemantics signedness,
|
||||
MLIRContext *context) {
|
||||
IntegerType IntegerType::get(MLIRContext *context, unsigned width,
|
||||
IntegerType::SignednessSemantics signedness) {
|
||||
if (auto cached = getCachedIntegerType(width, signedness, context))
|
||||
return cached;
|
||||
return Base::get(context, width, signedness);
|
||||
}
|
||||
|
||||
IntegerType IntegerType::getChecked(unsigned width, Location location) {
|
||||
return getChecked(width, IntegerType::Signless, location);
|
||||
IntegerType IntegerType::getChecked(Location location, unsigned width) {
|
||||
return getChecked(location, width, IntegerType::Signless);
|
||||
}
|
||||
|
||||
IntegerType IntegerType::getChecked(unsigned width,
|
||||
SignednessSemantics signedness,
|
||||
Location location) {
|
||||
IntegerType IntegerType::getChecked(Location location, unsigned width,
|
||||
SignednessSemantics signedness) {
|
||||
if (auto cached =
|
||||
getCachedIntegerType(width, signedness, location->getContext()))
|
||||
return cached;
|
||||
|
|
|
@ -178,7 +178,7 @@ Operation::Operation(Location location, OperationName name,
|
|||
if (hasSingleResult)
|
||||
resultType = resultTypes.front();
|
||||
else
|
||||
resultType = TupleType::get(resultTypes, location->getContext());
|
||||
resultType = TupleType::get(location->getContext(), resultTypes);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -63,7 +63,7 @@ void Value::setType(Type newType) {
|
|||
return;
|
||||
auto newTypes = llvm::to_vector<4>(curTypes);
|
||||
newTypes[resultNo] = newType;
|
||||
owner->resultType = TupleType::get(newTypes, newType.getContext());
|
||||
owner->resultType = TupleType::get(newType.getContext(), newTypes);
|
||||
}
|
||||
|
||||
/// If this value is the result of an Operation, return the operation that
|
||||
|
|
|
@ -563,8 +563,8 @@ Type Parser::parseExtendedType() {
|
|||
|
||||
// Otherwise, form a new opaque type.
|
||||
return OpaqueType::getChecked(
|
||||
Identifier::get(dialectName, state.context), symbolData,
|
||||
state.context, getEncodedSourceLocation(loc));
|
||||
getEncodedSourceLocation(loc),
|
||||
Identifier::get(dialectName, state.context), symbolData);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -338,7 +338,7 @@ Type Parser::parseNonFunctionType() {
|
|||
signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
|
||||
|
||||
consumeToken(Token::inttype);
|
||||
return IntegerType::get(width.getValue(), signSemantics, getContext());
|
||||
return IntegerType::get(getContext(), width.getValue(), signSemantics);
|
||||
}
|
||||
|
||||
// float-type
|
||||
|
@ -432,7 +432,7 @@ Type Parser::parseTupleType() {
|
|||
parseToken(Token::greater, "expected '>' in tuple type"))
|
||||
return nullptr;
|
||||
|
||||
return TupleType::get(types, getContext());
|
||||
return TupleType::get(getContext(), types);
|
||||
}
|
||||
|
||||
/// Parse a vector type.
|
||||
|
|
|
@ -236,7 +236,7 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
|
|||
Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
|
||||
if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
|
||||
return b.getIntegerAttr(
|
||||
IntegerType::get(ci->getType()->getBitWidth(), context),
|
||||
IntegerType::get(context, ci->getType()->getBitWidth()),
|
||||
ci->getValue());
|
||||
if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
|
||||
if (c->isString())
|
||||
|
|
|
@ -1182,7 +1182,7 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
|
|||
// signless semantics for such cases.
|
||||
auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
|
||||
: IntegerType::SignednessSemantics::Signless;
|
||||
typeMap[operands[0]] = IntegerType::get(operands[1], sign, context);
|
||||
typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
|
||||
} break;
|
||||
case spirv::Opcode::OpTypeFloat: {
|
||||
if (operands.size() != 2)
|
||||
|
@ -1345,7 +1345,7 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
|
|||
if (!isVoidType(returnType)) {
|
||||
returnTypes = llvm::makeArrayRef(returnType);
|
||||
}
|
||||
typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context);
|
||||
typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -1267,7 +1267,7 @@ LogicalResult Serializer::prepareBasicType(
|
|||
}
|
||||
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
|
||||
auto getConstantOp = [&](uint32_t id) {
|
||||
auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id);
|
||||
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
|
||||
return prepareConstantInt(loc, attr);
|
||||
};
|
||||
operands.push_back(elementTypeID);
|
||||
|
|
|
@ -35,8 +35,8 @@ static void updateFuncOp(FuncOp func,
|
|||
// Add the new arguments to the function type.
|
||||
auto newArgTypes = llvm::to_vector<6>(
|
||||
llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
|
||||
auto newFunctionType = FunctionType::get(
|
||||
newArgTypes, functionType.getResults(), func.getContext());
|
||||
auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
|
||||
functionType.getResults());
|
||||
func.setType(newFunctionType);
|
||||
|
||||
// Transfer the result attributes to arg attributes.
|
||||
|
|
|
@ -230,9 +230,8 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
|
|||
|
||||
// We create a new function type and modify the function signature with this
|
||||
// new type.
|
||||
newFuncType = FunctionType::get(/*inputs=*/argTypes,
|
||||
/*results=*/resultTypes,
|
||||
/*context=*/&getContext());
|
||||
newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
|
||||
/*results=*/resultTypes);
|
||||
}
|
||||
|
||||
// Since we update the function signature, it might affect the result types at
|
||||
|
@ -463,9 +462,9 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
|
|||
continue;
|
||||
}
|
||||
|
||||
FunctionType newFuncType = FunctionType::get(/*inputs=*/inputTypes,
|
||||
/*results=*/resultTypes,
|
||||
/*context=*/&getContext());
|
||||
FunctionType newFuncType =
|
||||
FunctionType::get(&getContext(), /*inputs=*/inputTypes,
|
||||
/*results=*/resultTypes);
|
||||
// Setting the new function signature for this external function.
|
||||
funcOp.setType(newFuncType);
|
||||
}
|
||||
|
|
|
@ -2522,8 +2522,8 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
|
|||
|
||||
// Update the function signature in-place.
|
||||
rewriter.updateRootInPlace(funcOp, [&] {
|
||||
funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults,
|
||||
funcOp.getContext()));
|
||||
funcOp.setType(FunctionType::get(funcOp.getContext(),
|
||||
result.getConvertedTypes(), newResults));
|
||||
});
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -56,7 +56,7 @@ static FuncOp makeFunction(StringRef name, ArrayRef<Type> results = {},
|
|||
ArrayRef<Type> args = {}) {
|
||||
auto &ctx = globalContext();
|
||||
auto function = FuncOp::create(UnknownLoc::get(&ctx), name,
|
||||
FunctionType::get(args, results, &ctx));
|
||||
FunctionType::get(&ctx, args, results));
|
||||
function.addEntryBlock();
|
||||
return function;
|
||||
}
|
||||
|
@ -277,7 +277,7 @@ TEST_FUNC(builder_blocks) {
|
|||
|
||||
TEST_FUNC(builder_cond_branch) {
|
||||
auto f = makeFunction("builder_cond_branch", {},
|
||||
{IntegerType::get(1, &globalContext())});
|
||||
{IntegerType::get(&globalContext(), 1)});
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
ScopedContext scope(builder, f.getLoc());
|
||||
|
@ -390,8 +390,8 @@ TEST_FUNC(insertion_in_block) {
|
|||
|
||||
TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) {
|
||||
using namespace edsc::op;
|
||||
auto i1Type = IntegerType::get(1, &globalContext());
|
||||
auto i8Type = IntegerType::get(8, &globalContext());
|
||||
auto i1Type = IntegerType::get(&globalContext(), 1);
|
||||
auto i8Type = IntegerType::get(&globalContext(), 8);
|
||||
auto memrefType = MemRefType::get({}, i1Type, {}, 0);
|
||||
auto f = makeFunction("zero_and_std_sign_extendi_op", {},
|
||||
{memrefType, memrefType});
|
||||
|
@ -414,7 +414,7 @@ TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) {
|
|||
}
|
||||
|
||||
TEST_FUNC(operator_or) {
|
||||
auto i1Type = IntegerType::get(/*width=*/1, &globalContext());
|
||||
auto i1Type = IntegerType::get(&globalContext(), /*width=*/1);
|
||||
auto f = makeFunction("operator_or", {}, {i1Type, i1Type});
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
|
@ -435,7 +435,7 @@ TEST_FUNC(operator_or) {
|
|||
}
|
||||
|
||||
TEST_FUNC(operator_and) {
|
||||
auto i1Type = IntegerType::get(/*width=*/1, &globalContext());
|
||||
auto i1Type = IntegerType::get(&globalContext(), /*width=*/1);
|
||||
auto f = makeFunction("operator_and", {}, {i1Type, i1Type});
|
||||
|
||||
OpBuilder builder(f.getBody());
|
||||
|
@ -536,7 +536,7 @@ TEST_FUNC(fptrunc_f32_bf16) {
|
|||
|
||||
TEST_FUNC(select_op_i32) {
|
||||
using namespace edsc::op;
|
||||
auto i32Type = IntegerType::get(32, &globalContext());
|
||||
auto i32Type = IntegerType::get(&globalContext(), 32);
|
||||
auto memrefType = MemRefType::get(
|
||||
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, i32Type, {}, 0);
|
||||
auto f = makeFunction("select_op", {}, {memrefType});
|
||||
|
|
|
@ -653,7 +653,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
|
|||
}
|
||||
int64_t dim =
|
||||
sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
|
||||
auto type = IntegerType::get(17, context);
|
||||
auto type = IntegerType::get(context, 17);
|
||||
inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -509,7 +509,7 @@ struct TestTypeConverter : public TypeConverter {
|
|||
|
||||
// Convert I42 to I43.
|
||||
if (t.isInteger(42)) {
|
||||
results.push_back(IntegerType::get(43, t.getContext()));
|
||||
results.push_back(IntegerType::get(t.getContext(), 43));
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -69,9 +69,7 @@ struct TestDecomposeCallGraphTypes
|
|||
Location loc) -> Optional<Value> {
|
||||
if (inputs.size() == 1)
|
||||
return llvm::None;
|
||||
TypeRange TypeRange = inputs.getTypes();
|
||||
SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
|
||||
TupleType tuple = TupleType::get(types, builder.getContext());
|
||||
TupleType tuple = builder.getTupleType(inputs.getTypes());
|
||||
Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
|
||||
return value;
|
||||
});
|
||||
|
|
|
@ -59,7 +59,7 @@ ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
|
|||
} else {
|
||||
tensorType = RankedTensorType::get(shape, eleType);
|
||||
}
|
||||
auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx));
|
||||
auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(ctx, 64));
|
||||
auto indices =
|
||||
DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
|
||||
auto valuesType = RankedTensorType::get({1}, eleType);
|
||||
|
@ -77,7 +77,7 @@ UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
|
|||
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
|
||||
MLIRContext ctx;
|
||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
||||
IntegerType convertedType = IntegerType::get(&ctx, 8);
|
||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||
|
||||
|
@ -95,7 +95,7 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
|
|||
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
|
||||
MLIRContext ctx;
|
||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
||||
IntegerType convertedType = IntegerType::get(&ctx, 8);
|
||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||
auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
|
||||
|
@ -120,7 +120,7 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
|
|||
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
|
||||
MLIRContext ctx;
|
||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
||||
IntegerType convertedType = IntegerType::get(&ctx, 8);
|
||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||
auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
|
||||
|
@ -145,7 +145,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
|
|||
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
|
||||
MLIRContext ctx;
|
||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
||||
IntegerType convertedType = IntegerType::get(&ctx, 8);
|
||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||
auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});
|
||||
|
|
|
@ -33,7 +33,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
|
|||
namespace {
|
||||
TEST(DenseSplatTest, BoolSplat) {
|
||||
MLIRContext context;
|
||||
IntegerType boolTy = IntegerType::get(1, &context);
|
||||
IntegerType boolTy = IntegerType::get(&context, 1);
|
||||
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
|
||||
|
||||
// Check that splat is automatically detected for boolean values.
|
||||
|
@ -58,7 +58,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
|
|||
constexpr int64_t boolCount = 56;
|
||||
|
||||
MLIRContext context;
|
||||
IntegerType boolTy = IntegerType::get(1, &context);
|
||||
IntegerType boolTy = IntegerType::get(&context, 1);
|
||||
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
|
||||
|
||||
// Check that splat is automatically detected for boolean values.
|
||||
|
@ -81,7 +81,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
|
|||
|
||||
TEST(DenseSplatTest, BoolNonSplat) {
|
||||
MLIRContext context;
|
||||
IntegerType boolTy = IntegerType::get(1, &context);
|
||||
IntegerType boolTy = IntegerType::get(&context, 1);
|
||||
RankedTensorType shape = RankedTensorType::get({6}, boolTy);
|
||||
|
||||
// Check that we properly handle non-splat values.
|
||||
|
@ -94,7 +94,7 @@ TEST(DenseSplatTest, OddIntSplat) {
|
|||
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
|
||||
MLIRContext context;
|
||||
constexpr size_t intWidth = 19;
|
||||
IntegerType intTy = IntegerType::get(intWidth, &context);
|
||||
IntegerType intTy = IntegerType::get(&context, intWidth);
|
||||
APInt value(intWidth, 10);
|
||||
|
||||
testSplat(intTy, value);
|
||||
|
@ -102,7 +102,7 @@ TEST(DenseSplatTest, OddIntSplat) {
|
|||
|
||||
TEST(DenseSplatTest, Int32Splat) {
|
||||
MLIRContext context;
|
||||
IntegerType intTy = IntegerType::get(32, &context);
|
||||
IntegerType intTy = IntegerType::get(&context, 32);
|
||||
int value = 64;
|
||||
|
||||
testSplat(intTy, value);
|
||||
|
@ -110,7 +110,7 @@ TEST(DenseSplatTest, Int32Splat) {
|
|||
|
||||
TEST(DenseSplatTest, IntAttrSplat) {
|
||||
MLIRContext context;
|
||||
IntegerType intTy = IntegerType::get(85, &context);
|
||||
IntegerType intTy = IntegerType::get(&context, 85);
|
||||
Attribute value = IntegerAttr::get(intTy, 109);
|
||||
|
||||
testSplat(intTy, value);
|
||||
|
@ -151,7 +151,7 @@ TEST(DenseSplatTest, BF16Splat) {
|
|||
TEST(DenseSplatTest, StringSplat) {
|
||||
MLIRContext context;
|
||||
Type stringType =
|
||||
OpaqueType::get(Identifier::get("test", &context), "string", &context);
|
||||
OpaqueType::get(&context, Identifier::get("test", &context), "string");
|
||||
StringRef value = "test-string";
|
||||
testSplat(stringType, value);
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ TEST(DenseSplatTest, StringSplat) {
|
|||
TEST(DenseSplatTest, StringAttrSplat) {
|
||||
MLIRContext context;
|
||||
Type stringType =
|
||||
OpaqueType::get(Identifier::get("test", &context), "string", &context);
|
||||
OpaqueType::get(&context, Identifier::get("test", &context), "string");
|
||||
Attribute stringAttr = StringAttr::get("test-string", stringType);
|
||||
testSplat(stringType, stringAttr);
|
||||
}
|
||||
|
@ -173,7 +173,7 @@ TEST(DenseComplexTest, ComplexFloatSplat) {
|
|||
|
||||
TEST(DenseComplexTest, ComplexIntSplat) {
|
||||
MLIRContext context;
|
||||
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
|
||||
ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
|
||||
std::complex<int64_t> value(10, 15);
|
||||
testSplat(complexType, value);
|
||||
}
|
||||
|
@ -187,7 +187,7 @@ TEST(DenseComplexTest, ComplexAPFloatSplat) {
|
|||
|
||||
TEST(DenseComplexTest, ComplexAPIntSplat) {
|
||||
MLIRContext context;
|
||||
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
|
||||
ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
|
||||
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
|
||||
testSplat(complexType, value);
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@ namespace mlir {
|
|||
/// Helper that returns an example test::TestStruct for testing its
|
||||
/// implementation.
|
||||
static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
|
||||
auto integerType = mlir::IntegerType::get(32, context);
|
||||
auto integerType = mlir::IntegerType::get(context, 32);
|
||||
auto integerAttr = mlir::IntegerAttr::get(integerType, 127);
|
||||
|
||||
auto floatType = mlir::FloatType::getF32(context);
|
||||
|
@ -105,7 +105,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) {
|
|||
expectedValues.begin(), expectedValues.end() - 1);
|
||||
|
||||
// Add a copy of the last attribute with the wrong type.
|
||||
auto i64Type = mlir::IntegerType::get(64, &context);
|
||||
auto i64Type = mlir::IntegerType::get(&context, 64);
|
||||
auto elementsType = mlir::RankedTensorType::get({3}, i64Type);
|
||||
auto elementsAttr =
|
||||
mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3});
|
||||
|
|
Loading…
Reference in New Issue