[mlir][IR][NFC] Move context/location parameters of builtin Type::get methods to the start of the parameter list

This better matches the rest of the infrastructure, is much simpler, and makes it easier to move these types to being declaratively specified.

Differential Revision: https://reviews.llvm.org/D93432
This commit is contained in:
River Riddle 2020-12-17 12:24:45 -08:00
parent 511cfe9441
commit 1b97cdf885
52 changed files with 205 additions and 209 deletions

View File

@ -2176,7 +2176,7 @@ def fir_DispatchOp : fir_Op<"dispatch",
p.printOptionalAttrDict(getAttrs(), {"fn_type", "method"}); p.printOptionalAttrDict(getAttrs(), {"fn_type", "method"});
auto resTy{getResultTypes()}; auto resTy{getResultTypes()};
llvm::SmallVector<mlir::Type, 8> argTy(getOperandTypes()); llvm::SmallVector<mlir::Type, 8> argTy(getOperandTypes());
p << " : " << mlir::FunctionType::get(argTy, resTy, getContext()); p << " : " << mlir::FunctionType::get(getContext(), argTy, resTy);
}]; }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{

View File

@ -49,7 +49,7 @@ mlir::Type genFIRType(mlir::MLIRContext *context) {
if constexpr (TC == Fortran::common::TypeCategory::Integer) { if constexpr (TC == Fortran::common::TypeCategory::Integer) {
auto bits{Fortran::evaluate::Type<Fortran::common::TypeCategory::Integer, auto bits{Fortran::evaluate::Type<Fortran::common::TypeCategory::Integer,
KIND>::Scalar::bits}; KIND>::Scalar::bits};
return mlir::IntegerType::get(bits, context); return mlir::IntegerType::get(context, bits);
} else if constexpr (TC == Fortran::common::TypeCategory::Logical || } else if constexpr (TC == Fortran::common::TypeCategory::Logical ||
TC == Fortran::common::TypeCategory::Character || TC == Fortran::common::TypeCategory::Character ||
TC == Fortran::common::TypeCategory::Complex) { TC == Fortran::common::TypeCategory::Complex) {
@ -278,7 +278,7 @@ private:
// some sequence of `n` bytes // some sequence of `n` bytes
mlir::Type gen(const Fortran::evaluate::StaticDataObject::Pointer &ptr) { mlir::Type gen(const Fortran::evaluate::StaticDataObject::Pointer &ptr) {
mlir::Type byteTy{mlir::IntegerType::get(8, context)}; mlir::Type byteTy{mlir::IntegerType::get(context, 8)};
return fir::SequenceType::get(trivialShape(ptr->itemBytes()), byteTy); return fir::SequenceType::get(trivialShape(ptr->itemBytes()), byteTy);
} }

View File

@ -298,26 +298,26 @@ static constexpr RuntimeFunction pgmathPrecise[] = {
static mlir::FunctionType genF32F32FuncType(mlir::MLIRContext *context) { static mlir::FunctionType genF32F32FuncType(mlir::MLIRContext *context) {
auto t = mlir::FloatType::getF32(context); auto t = mlir::FloatType::getF32(context);
return mlir::FunctionType::get({t}, {t}, context); return mlir::FunctionType::get(context, {t}, {t});
} }
static mlir::FunctionType genF64F64FuncType(mlir::MLIRContext *context) { static mlir::FunctionType genF64F64FuncType(mlir::MLIRContext *context) {
auto t = mlir::FloatType::getF64(context); auto t = mlir::FloatType::getF64(context);
return mlir::FunctionType::get({t}, {t}, context); return mlir::FunctionType::get(context, {t}, {t});
} }
template <int Bits> template <int Bits>
static mlir::FunctionType genIntF64FuncType(mlir::MLIRContext *context) { static mlir::FunctionType genIntF64FuncType(mlir::MLIRContext *context) {
auto t = mlir::FloatType::getF64(context); auto t = mlir::FloatType::getF64(context);
auto r = mlir::IntegerType::get(Bits, context); auto r = mlir::IntegerType::get(context, Bits);
return mlir::FunctionType::get({t}, {r}, context); return mlir::FunctionType::get(context, {t}, {r});
} }
template <int Bits> template <int Bits>
static mlir::FunctionType genIntF32FuncType(mlir::MLIRContext *context) { static mlir::FunctionType genIntF32FuncType(mlir::MLIRContext *context) {
auto t = mlir::FloatType::getF32(context); auto t = mlir::FloatType::getF32(context);
auto r = mlir::IntegerType::get(Bits, context); auto r = mlir::IntegerType::get(context, Bits);
return mlir::FunctionType::get({t}, {r}, context); return mlir::FunctionType::get(context, {t}, {r});
} }
// TODO : Fill-up this table with more intrinsic. // TODO : Fill-up this table with more intrinsic.
@ -585,8 +585,8 @@ getFunctionType(mlir::Type resultType, llvm::ArrayRef<mlir::Value> arguments,
llvm::SmallVector<mlir::Type, 2> argumentTypes; llvm::SmallVector<mlir::Type, 2> argumentTypes;
for (auto &arg : arguments) for (auto &arg : arguments)
argumentTypes.push_back(arg.getType()); argumentTypes.push_back(arg.getType());
return mlir::FunctionType::get(argumentTypes, resultType, return mlir::FunctionType::get(builder.getModule().getContext(),
builder.getModule().getContext()); argumentTypes, resultType);
} }
/// fir::ExtendedValue to mlir::Value translation layer /// fir::ExtendedValue to mlir::Value translation layer
@ -1144,7 +1144,7 @@ mlir::Value IntrinsicLibrary::genMerge(mlir::Type,
llvm::ArrayRef<mlir::Value> args) { llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 3); assert(args.size() == 3);
auto i1Type = mlir::IntegerType::get(1, builder.getContext()); auto i1Type = mlir::IntegerType::get(builder.getContext(), 1);
auto mask = builder.createConvert(loc, i1Type, args[2]); auto mask = builder.createConvert(loc, i1Type, args[2]);
return builder.create<mlir::SelectOp>(loc, mask, args[0], args[1]); return builder.create<mlir::SelectOp>(loc, mask, args[0], args[1]);
} }

View File

@ -48,7 +48,7 @@ static constexpr TypeBuilderFunc getModel();
template <> template <>
constexpr TypeBuilderFunc getModel<int>() { constexpr TypeBuilderFunc getModel<int>() {
return [](mlir::MLIRContext *context) -> mlir::Type { return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(8 * sizeof(int), context); return mlir::IntegerType::get(context, 8 * sizeof(int));
}; };
} }
template <> template <>
@ -61,14 +61,14 @@ constexpr TypeBuilderFunc getModel<int &>() {
template <> template <>
constexpr TypeBuilderFunc getModel<Fortran::runtime::io::Iostat>() { constexpr TypeBuilderFunc getModel<Fortran::runtime::io::Iostat>() {
return [](mlir::MLIRContext *context) -> mlir::Type { return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(8 * sizeof(Fortran::runtime::io::Iostat), return mlir::IntegerType::get(context,
context); 8 * sizeof(Fortran::runtime::io::Iostat));
}; };
} }
template <> template <>
constexpr TypeBuilderFunc getModel<char *>() { constexpr TypeBuilderFunc getModel<char *>() {
return [](mlir::MLIRContext *context) -> mlir::Type { return [](mlir::MLIRContext *context) -> mlir::Type {
return fir::ReferenceType::get(mlir::IntegerType::get(8, context)); return fir::ReferenceType::get(mlir::IntegerType::get(context, 8));
}; };
} }
template <> template <>
@ -78,26 +78,26 @@ constexpr TypeBuilderFunc getModel<const char *>() {
template <> template <>
constexpr TypeBuilderFunc getModel<const char16_t *>() { constexpr TypeBuilderFunc getModel<const char16_t *>() {
return [](mlir::MLIRContext *context) -> mlir::Type { return [](mlir::MLIRContext *context) -> mlir::Type {
return fir::ReferenceType::get(mlir::IntegerType::get(16, context)); return fir::ReferenceType::get(mlir::IntegerType::get(context, 16));
}; };
} }
template <> template <>
constexpr TypeBuilderFunc getModel<const char32_t *>() { constexpr TypeBuilderFunc getModel<const char32_t *>() {
return [](mlir::MLIRContext *context) -> mlir::Type { return [](mlir::MLIRContext *context) -> mlir::Type {
return fir::ReferenceType::get(mlir::IntegerType::get(32, context)); return fir::ReferenceType::get(mlir::IntegerType::get(context, 32));
}; };
} }
template <> template <>
constexpr TypeBuilderFunc getModel<void **>() { constexpr TypeBuilderFunc getModel<void **>() {
return [](mlir::MLIRContext *context) -> mlir::Type { return [](mlir::MLIRContext *context) -> mlir::Type {
return fir::ReferenceType::get( return fir::ReferenceType::get(
fir::PointerType::get(mlir::IntegerType::get(8, context))); fir::PointerType::get(mlir::IntegerType::get(context, 8)));
}; };
} }
template <> template <>
constexpr TypeBuilderFunc getModel<std::int64_t>() { constexpr TypeBuilderFunc getModel<std::int64_t>() {
return [](mlir::MLIRContext *context) -> mlir::Type { return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(64, context); return mlir::IntegerType::get(context, 64);
}; };
} }
template <> template <>
@ -110,7 +110,7 @@ constexpr TypeBuilderFunc getModel<std::int64_t &>() {
template <> template <>
constexpr TypeBuilderFunc getModel<std::size_t>() { constexpr TypeBuilderFunc getModel<std::size_t>() {
return [](mlir::MLIRContext *context) -> mlir::Type { return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(8 * sizeof(std::size_t), context); return mlir::IntegerType::get(context, 8 * sizeof(std::size_t));
}; };
} }
template <> template <>
@ -146,7 +146,7 @@ constexpr TypeBuilderFunc getModel<float &>() {
template <> template <>
constexpr TypeBuilderFunc getModel<bool>() { constexpr TypeBuilderFunc getModel<bool>() {
return [](mlir::MLIRContext *context) -> mlir::Type { return [](mlir::MLIRContext *context) -> mlir::Type {
return mlir::IntegerType::get(1, context); return mlir::IntegerType::get(context, 1);
}; };
} }
template <> template <>
@ -190,7 +190,7 @@ struct RuntimeTableKey<RT(ATs...)> {
llvm::SmallVector<mlir::Type, sizeof...(ATs)> argTys; llvm::SmallVector<mlir::Type, sizeof...(ATs)> argTys;
for (auto f : args) for (auto f : args)
argTys.push_back(f(ctxt)); argTys.push_back(f(ctxt));
return mlir::FunctionType::get(argTys, {retTy}, ctxt); return mlir::FunctionType::get(ctxt, argTys, {retTy});
}; };
} }
}; };

View File

@ -151,7 +151,7 @@ mlir::Type fir::BoxDimsOp::getTupleType() {
// note: triple, but 4 is nearest power of 2 // note: triple, but 4 is nearest power of 2
llvm::SmallVector<mlir::Type, 4> triple{ llvm::SmallVector<mlir::Type, 4> triple{
getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; getResult(0).getType(), getResult(1).getType(), getResult(2).getType()};
return mlir::TupleType::get(triple, getContext()); return mlir::TupleType::get(getContext(), triple);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -171,7 +171,7 @@ static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) {
auto resultTypes{op.getResultTypes()}; auto resultTypes{op.getResultTypes()};
llvm::SmallVector<Type, 8> argTypes( llvm::SmallVector<Type, 8> argTypes(
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1)); llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext()); p << " : " << FunctionType::get(op.getContext(), argTypes, resultTypes);
} }
static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser, static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser,
@ -1565,4 +1565,3 @@ fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "flang/Optimizer/Dialect/FIROps.cpp.inc" #include "flang/Optimizer/Dialect/FIROps.cpp.inc"

View File

@ -35,8 +35,8 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect,
AllTypesMatch<["src", "a", "dst"]>, AllTypesMatch<["src", "a", "dst"]>,
TypesMatchWith<"imm has the same number of bits as elements in dst", TypesMatchWith<"imm has the same number of bits as elements in dst",
"dst", "imm", "dst", "imm",
"IntegerType::get(($_self.cast<VectorType>().getShape()[0])," "IntegerType::get($_self.getContext(), "
" $_self.getContext())">]> { "($_self.cast<VectorType>().getShape()[0]))">]> {
let summary = "Masked roundscale op"; let summary = "Masked roundscale op";
let description = [{ let description = [{
The mask.rndscale op is an AVX512 specific op that can lower to the proper The mask.rndscale op is an AVX512 specific op that can lower to the proper
@ -67,8 +67,8 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
AllTypesMatch<["src", "a", "b", "dst"]>, AllTypesMatch<["src", "a", "b", "dst"]>,
TypesMatchWith<"k has the same number of bits as elements in dst", TypesMatchWith<"k has the same number of bits as elements in dst",
"dst", "k", "dst", "k",
"IntegerType::get(($_self.cast<VectorType>().getShape()[0])," "IntegerType::get($_self.getContext(), "
" $_self.getContext())">]> { "($_self.cast<VectorType>().getShape()[0]))">]> {
let summary = "ScaleF op"; let summary = "ScaleF op";
let description = [{ let description = [{
The `mask.scalef` op is an AVX512 specific op that can lower to the proper The `mask.scalef` op is an AVX512 specific op that can lower to the proper

View File

@ -911,7 +911,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
auto attr = (*this)->getAttr("operand_segment_sizes") auto attr = (*this)->getAttr("operand_segment_sizes")
.cast<DenseIntElementsAttr>(); .cast<DenseIntElementsAttr>();
unsigned i = 0; unsigned i = 0;
auto newAttr = attr.mapValues(IntegerType::get(32, getContext()), auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
getOperation()->setAttr("operand_segment_sizes", newAttr); getOperation()->setAttr("operand_segment_sizes", newAttr);
} }

View File

@ -63,7 +63,7 @@ public:
/// Get or create a ComplexType with the provided element type. This emits /// Get or create a ComplexType with the provided element type. This emits
/// and error at the specified location and returns null if the element type /// and error at the specified location and returns null if the element type
/// isn't supported. /// isn't supported.
static ComplexType getChecked(Type elementType, Location location); static ComplexType getChecked(Location location, Type elementType);
/// Verify the construction of an integer type. /// Verify the construction of an integer type.
static LogicalResult verifyConstructionInvariants(Location loc, static LogicalResult verifyConstructionInvariants(Location loc,
@ -93,27 +93,27 @@ public:
/// The created IntegerType is signless (i.e., no signedness semantics). /// The created IntegerType is signless (i.e., no signedness semantics).
/// Assume the width is within the allowed range and assert on failures. Use /// Assume the width is within the allowed range and assert on failures. Use
/// getChecked to handle failures gracefully. /// getChecked to handle failures gracefully.
static IntegerType get(unsigned width, MLIRContext *context); static IntegerType get(MLIRContext *context, unsigned width);
/// Get or create a new IntegerType of the given width within the context. /// Get or create a new IntegerType of the given width within the context.
/// The created IntegerType has signedness semantics as indicated via /// The created IntegerType has signedness semantics as indicated via
/// `signedness`. Assume the width is within the allowed range and assert on /// `signedness`. Assume the width is within the allowed range and assert on
/// failures. Use getChecked to handle failures gracefully. /// failures. Use getChecked to handle failures gracefully.
static IntegerType get(unsigned width, SignednessSemantics signedness, static IntegerType get(MLIRContext *context, unsigned width,
MLIRContext *context); SignednessSemantics signedness);
/// Get or create a new IntegerType of the given width within the context, /// Get or create a new IntegerType of the given width within the context,
/// defined at the given, potentially unknown, location. The created /// defined at the given, potentially unknown, location. The created
/// IntegerType is signless (i.e., no signedness semantics). If the width is /// IntegerType is signless (i.e., no signedness semantics). If the width is
/// outside the allowed range, emit errors and return a null type. /// outside the allowed range, emit errors and return a null type.
static IntegerType getChecked(unsigned width, Location location); static IntegerType getChecked(Location location, unsigned width);
/// Get or create a new IntegerType of the given width within the context, /// Get or create a new IntegerType of the given width within the context,
/// defined at the given, potentially unknown, location. The created /// defined at the given, potentially unknown, location. The created
/// IntegerType has signedness semantics as indicated via `signedness`. If the /// IntegerType has signedness semantics as indicated via `signedness`. If the
/// width is outside the allowed range, emit errors and return a null type. /// width is outside the allowed range, emit errors and return a null type.
static IntegerType getChecked(unsigned width, SignednessSemantics signedness, static IntegerType getChecked(Location location, unsigned width,
Location location); SignednessSemantics signedness);
/// Verify the construction of an integer type. /// Verify the construction of an integer type.
static LogicalResult static LogicalResult
@ -180,8 +180,8 @@ class FunctionType
public: public:
using Base::Base; using Base::Base;
static FunctionType get(TypeRange inputs, TypeRange results, static FunctionType get(MLIRContext *context, TypeRange inputs,
MLIRContext *context); TypeRange results);
/// Input types. /// Input types.
unsigned getNumInputs() const; unsigned getNumInputs() const;
@ -211,14 +211,14 @@ public:
using Base::Base; using Base::Base;
/// Get or create a new OpaqueType with the provided dialect and string data. /// Get or create a new OpaqueType with the provided dialect and string data.
static OpaqueType get(Identifier dialect, StringRef typeData, static OpaqueType get(MLIRContext *context, Identifier dialect,
MLIRContext *context); StringRef typeData);
/// Get or create a new OpaqueType with the provided dialect and string data. /// Get or create a new OpaqueType with the provided dialect and string data.
/// If the given identifier is not a valid namespace for a dialect, then a /// If the given identifier is not a valid namespace for a dialect, then a
/// null type is returned. /// null type is returned.
static OpaqueType getChecked(Identifier dialect, StringRef typeData, static OpaqueType getChecked(Location location, Identifier dialect,
MLIRContext *context, Location location); StringRef typeData);
/// Returns the dialect namespace of the opaque type. /// Returns the dialect namespace of the opaque type.
Identifier getDialectNamespace() const; Identifier getDialectNamespace() const;
@ -335,8 +335,8 @@ public:
/// declared at the given, potentially unknown, location. If the VectorType /// declared at the given, potentially unknown, location. If the VectorType
/// defined by the arguments would be ill-formed, emit errors and return /// defined by the arguments would be ill-formed, emit errors and return
/// nullptr-wrapping type. /// nullptr-wrapping type.
static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType, static VectorType getChecked(Location location, ArrayRef<int64_t> shape,
Location location); Type elementType);
/// Verify the construction of a vector type. /// Verify the construction of a vector type.
static LogicalResult verifyConstructionInvariants(Location loc, static LogicalResult verifyConstructionInvariants(Location loc,
@ -394,8 +394,8 @@ public:
/// type declared at the given, potentially unknown, location. If the /// type declared at the given, potentially unknown, location. If the
/// RankedTensorType defined by the arguments would be ill-formed, emit errors /// RankedTensorType defined by the arguments would be ill-formed, emit errors
/// and return a nullptr-wrapping type. /// and return a nullptr-wrapping type.
static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType, static RankedTensorType getChecked(Location location, ArrayRef<int64_t> shape,
Location location); Type elementType);
/// Verify the construction of a ranked tensor type. /// Verify the construction of a ranked tensor type.
static LogicalResult verifyConstructionInvariants(Location loc, static LogicalResult verifyConstructionInvariants(Location loc,
@ -424,7 +424,7 @@ public:
/// type declared at the given, potentially unknown, location. If the /// type declared at the given, potentially unknown, location. If the
/// UnrankedTensorType defined by the arguments would be ill-formed, emit /// UnrankedTensorType defined by the arguments would be ill-formed, emit
/// errors and return a nullptr-wrapping type. /// errors and return a nullptr-wrapping type.
static UnrankedTensorType getChecked(Type elementType, Location location); static UnrankedTensorType getChecked(Location location, Type elementType);
/// Verify the construction of a unranked tensor type. /// Verify the construction of a unranked tensor type.
static LogicalResult verifyConstructionInvariants(Location loc, static LogicalResult verifyConstructionInvariants(Location loc,
@ -527,9 +527,10 @@ public:
/// UnknownLoc. If the MemRefType defined by the arguments would be /// UnknownLoc. If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to /// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr. /// the error stream) and returns nullptr.
static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType, static MemRefType getChecked(Location location, ArrayRef<int64_t> shape,
Type elementType,
ArrayRef<AffineMap> affineMapComposition, ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace, Location location); unsigned memorySpace);
ArrayRef<int64_t> getShape() const; ArrayRef<int64_t> getShape() const;
@ -573,8 +574,8 @@ public:
/// type and memory space declared at the given, potentially unknown, /// type and memory space declared at the given, potentially unknown,
/// location. If the UnrankedMemRefType defined by the arguments would be /// location. If the UnrankedMemRefType defined by the arguments would be
/// ill-formed, emit errors and return a nullptr-wrapping type. /// ill-formed, emit errors and return a nullptr-wrapping type.
static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace, static UnrankedMemRefType getChecked(Location location, Type elementType,
Location location); unsigned memorySpace);
/// Verify the construction of a unranked memref type. /// Verify the construction of a unranked memref type.
static LogicalResult verifyConstructionInvariants(Location loc, static LogicalResult verifyConstructionInvariants(Location loc,
@ -600,7 +601,7 @@ public:
/// Get or create a new TupleType with the provided element types. Assumes the /// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type. /// arguments define a well-formed type.
static TupleType get(TypeRange elementTypes, MLIRContext *context); static TupleType get(MLIRContext *context, TypeRange elementTypes);
/// Get or create an empty tuple type. /// Get or create an empty tuple type.
static TupleType get(MLIRContext *context); static TupleType get(MLIRContext *context);

View File

@ -475,8 +475,9 @@ def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
class OpaqueType<string dialect, string name, string description> class OpaqueType<string dialect, string name, string description>
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">, : Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
description>, description>,
BuildableType<"::mlir::OpaqueType::get($_builder.getIdentifier(\"" BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
# dialect # "\"), \"" # name # "\", $_builder.getContext())">; "$_builder.getIdentifier(\"" # dialect # "\"), \""
# name # "\")">;
// Function Type // Function Type

View File

@ -26,15 +26,15 @@ bool mlirTypeIsAInteger(MlirType type) {
} }
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) { MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
return wrap(IntegerType::get(bitwidth, unwrap(ctx))); return wrap(IntegerType::get(unwrap(ctx), bitwidth));
} }
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) { MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
return wrap(IntegerType::get(bitwidth, IntegerType::Signed, unwrap(ctx))); return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed));
} }
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) { MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
return wrap(IntegerType::get(bitwidth, IntegerType::Unsigned, unwrap(ctx))); return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned));
} }
unsigned mlirIntegerTypeGetWidth(MlirType type) { unsigned mlirIntegerTypeGetWidth(MlirType type) {
@ -172,8 +172,8 @@ MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape, MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape,
MlirType elementType, MlirLocation loc) { MlirType elementType, MlirLocation loc) {
return wrap(VectorType::getChecked( return wrap(VectorType::getChecked(
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType), unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(loc))); unwrap(elementType)));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -201,8 +201,8 @@ MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, const int64_t *shape,
MlirType elementType, MlirType elementType,
MlirLocation loc) { MlirLocation loc) {
return wrap(RankedTensorType::getChecked( return wrap(RankedTensorType::getChecked(
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType), unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
unwrap(loc))); unwrap(elementType)));
} }
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
@ -211,7 +211,7 @@ MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType, MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
MlirLocation loc) { MlirLocation loc) {
return wrap(UnrankedTensorType::getChecked(unwrap(elementType), unwrap(loc))); return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -244,8 +244,8 @@ MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
unsigned memorySpace, unsigned memorySpace,
MlirLocation loc) { MlirLocation loc) {
return wrap(MemRefType::getChecked( return wrap(MemRefType::getChecked(
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType), unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
llvm::None, memorySpace, unwrap(loc))); unwrap(elementType), llvm::None, memorySpace));
} }
intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) { intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
@ -272,8 +272,8 @@ MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType, MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
unsigned memorySpace, unsigned memorySpace,
MlirLocation loc) { MlirLocation loc) {
return wrap(UnrankedMemRefType::getChecked(unwrap(elementType), memorySpace, return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
unwrap(loc))); memorySpace));
} }
unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) { unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
@ -290,7 +290,7 @@ MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
MlirType const *elements) { MlirType const *elements) {
SmallVector<Type, 4> types; SmallVector<Type, 4> types;
ArrayRef<Type> typeRef = unwrapList(numElements, elements, types); ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
return wrap(TupleType::get(typeRef, unwrap(ctx))); return wrap(TupleType::get(unwrap(ctx), typeRef));
} }
intptr_t mlirTupleTypeGetNumTypes(MlirType type) { intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
@ -316,7 +316,7 @@ MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
SmallVector<Type, 4> resultsList; SmallVector<Type, 4> resultsList;
(void)unwrapList(numInputs, inputs, inputsList); (void)unwrapList(numInputs, inputs, inputsList);
(void)unwrapList(numResults, results, resultsList); (void)unwrapList(numResults, results, resultsList);
return wrap(FunctionType::get(inputsList, resultsList, unwrap(ctx))); return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList));
} }
intptr_t mlirFunctionTypeGetNumInputs(MlirType type) { intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {

View File

@ -53,52 +53,52 @@ namespace {
struct AsyncAPI { struct AsyncAPI {
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) { static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
auto ref = LLVM::LLVMType::getInt8PtrTy(ctx); auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
auto count = IntegerType::get(32, ctx); auto count = IntegerType::get(ctx, 32);
return FunctionType::get({ref, count}, {}, ctx); return FunctionType::get(ctx, {ref, count}, {});
} }
static FunctionType createTokenFunctionType(MLIRContext *ctx) { static FunctionType createTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get({}, {TokenType::get(ctx)}, ctx); return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
} }
static FunctionType createGroupFunctionType(MLIRContext *ctx) { static FunctionType createGroupFunctionType(MLIRContext *ctx) {
return FunctionType::get({}, {GroupType::get(ctx)}, ctx); return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
} }
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) { static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get({TokenType::get(ctx)}, {}, ctx); return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
} }
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) { static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
return FunctionType::get({TokenType::get(ctx)}, {}, ctx); return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
} }
static FunctionType awaitGroupFunctionType(MLIRContext *ctx) { static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
return FunctionType::get({GroupType::get(ctx)}, {}, ctx); return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
} }
static FunctionType executeFunctionType(MLIRContext *ctx) { static FunctionType executeFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
auto resume = resumeFunctionType(ctx).getPointerTo(); auto resume = resumeFunctionType(ctx).getPointerTo();
return FunctionType::get({hdl, resume}, {}, ctx); return FunctionType::get(ctx, {hdl, resume}, {});
} }
static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) { static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
auto i64 = IntegerType::get(64, ctx); auto i64 = IntegerType::get(ctx, 64);
return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64}, return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
ctx); {i64});
} }
static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
auto resume = resumeFunctionType(ctx).getPointerTo(); auto resume = resumeFunctionType(ctx).getPointerTo();
return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx); return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
} }
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
auto resume = resumeFunctionType(ctx).getPointerTo(); auto resume = resumeFunctionType(ctx).getPointerTo();
return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx); return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
} }
// Auxiliary coroutine resume intrinsic wrapper. // Auxiliary coroutine resume intrinsic wrapper.
@ -690,7 +690,7 @@ public:
if (!addToGroup.operand().getType().isa<TokenType>()) if (!addToGroup.operand().getType().isa<TokenType>())
return failure(); return failure();
auto i64 = IntegerType::get(64, op->getContext()); auto i64 = IntegerType::get(op->getContext(), 64);
rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands); rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
return success(); return success();
} }

View File

@ -122,7 +122,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
} }
// Declare vulkan launch function. // Declare vulkan launch function.
auto funcType = FunctionType::get(vulkanLaunchTypes, {}, loc->getContext()); auto funcType = builder.getFunctionType(vulkanLaunchTypes, {});
builder.create<FuncOp>(loc, kVulkanLaunch, funcType).setPrivate(); builder.create<FuncOp>(loc, kVulkanLaunch, funcType).setPrivate();
return success(); return success();

View File

@ -84,7 +84,7 @@ static LLVMType getPtrToElementType(T containerType,
/// }; /// };
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) { static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
auto *context = t.getContext(); auto *context = t.getContext();
auto int64Ty = converter.convertType(IntegerType::get(64, context)) auto int64Ty = converter.convertType(IntegerType::get(context, 64))
.cast<LLVM::LLVMType>(); .cast<LLVM::LLVMType>();
return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
} }

View File

@ -65,7 +65,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
assert(op->getNumResults() == 0 && assert(op->getNumResults() == 0 &&
"Library call for linalg operation can be generated only for ops that " "Library call for linalg operation can be generated only for ops that "
"have void return types"); "have void return types");
auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext()); auto libFnType = rewriter.getFunctionType(inputTypes, {});
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
// Insert before module terminator. // Insert before module terminator.

View File

@ -407,8 +407,7 @@ public:
// cover all possible corner cases. // cover all possible corner cases.
if (isSignedIntegerOrVector(srcType) || if (isSignedIntegerOrVector(srcType) ||
isUnsignedIntegerOrVector(srcType)) { isUnsignedIntegerOrVector(srcType)) {
auto *context = rewriter.getContext(); auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
auto signlessType = IntegerType::get(getBitWidth(srcType), context);
if (srcType.isa<VectorType>()) { if (srcType.isa<VectorType>()) {
auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>(); auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();

View File

@ -584,7 +584,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
std::swap(ivsStorage.back(), ivsStorage[coalescedIdx]); std::swap(ivsStorage.back(), ivsStorage[coalescedIdx]);
ArrayRef<Value> ivs(ivsStorage); ArrayRef<Value> ivs(ivsStorage);
Value pos = std_index_cast(IntegerType::get(32, ctx), ivs.back()); Value pos = std_index_cast(IntegerType::get(ctx, 32), ivs.back());
Value inVector = local(ivs.drop_back()); Value inVector = local(ivs.drop_back());
auto loadValue = [&](ArrayRef<Value> indices) { auto loadValue = [&](ArrayRef<Value> indices) {
Value vector = vector_insert_element(remote(indices), inVector, pos); Value vector = vector_insert_element(remote(indices), inVector, pos);
@ -671,7 +671,7 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
ArrayRef<Value> ivs(ivsStorage); ArrayRef<Value> ivs(ivsStorage);
Value pos = Value pos =
std_index_cast(IntegerType::get(32, op->getContext()), ivs.back()); std_index_cast(IntegerType::get(op->getContext(), 32), ivs.back());
auto storeValue = [&](ArrayRef<Value> indices) { auto storeValue = [&](ArrayRef<Value> indices) {
Value scalar = vector_extract_element(local(ivs.drop_back()), pos); Value scalar = vector_extract_element(local(ivs.drop_back()), pos);
remote(indices) = scalar; remote(indices) = scalar;

View File

@ -152,7 +152,7 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
int32_t numDependencies = dependencies.size(); int32_t numDependencies = dependencies.size();
int32_t numOperands = operands.size(); int32_t numOperands = operands.size();
auto operandSegmentSizes = DenseIntElementsAttr::get( auto operandSegmentSizes = DenseIntElementsAttr::get(
VectorType::get({2}, IntegerType::get(32, result.getContext())), VectorType::get({2}, builder.getIntegerType(32)),
{numDependencies, numOperands}); {numDependencies, numOperands});
result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);

View File

@ -118,7 +118,7 @@ LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
builder.setInsertionPointToStart(value.getParentBlock()); builder.setInsertionPointToStart(value.getParentBlock());
Location loc = value.getLoc(); Location loc = value.getLoc();
auto i32 = IntegerType::get(32, ctx); auto i32 = IntegerType::get(ctx, 32);
// Drop the reference count immediately if the value has no uses. // Drop the reference count immediately if the value has no uses.
if (value.getUses().empty()) { if (value.getUses().empty()) {

View File

@ -31,7 +31,7 @@ struct GpuAllReduceRewriter {
: funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_), : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()), loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
indexType(IndexType::get(reduceOp.getContext())), indexType(IndexType::get(reduceOp.getContext())),
int32Type(IntegerType::get(/*width=*/32, reduceOp.getContext())) {} int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}
/// Creates an all_reduce across the workgroup. /// Creates an all_reduce across the workgroup.
/// ///

View File

@ -155,7 +155,7 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
kernelOperandTypes.push_back(operand.getType()); kernelOperandTypes.push_back(operand.getType());
} }
FunctionType type = FunctionType type =
FunctionType::get(kernelOperandTypes, {}, launchOp.getContext()); FunctionType::get(launchOp.getContext(), kernelOperandTypes, {});
auto outlinedFunc = builder.create<gpu::GPUFuncOp>(loc, kernelFnName, type); auto outlinedFunc = builder.create<gpu::GPUFuncOp>(loc, kernelFnName, type);
outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(), outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr()); builder.getUnitAttr());

View File

@ -120,8 +120,8 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy(); auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()}, auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()},
op.getContext()); {op.getType()});
p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy; p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
if (op.alignment().hasValue() && *op.alignment() != 0) if (op.alignment().hasValue() && *op.alignment() != 0)
@ -781,7 +781,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
// Reconstruct the function MLIR function type from operand and result types. // Reconstruct the function MLIR function type from operand and result types.
p << " : " p << " : "
<< FunctionType::get(args.getTypes(), op.getResultTypes(), op.getContext()); << FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes());
} }
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)` // <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`

View File

@ -76,25 +76,25 @@ static Value allocBuffer(const LinalgPromotionOptions &options,
IntegerAttr alignment_attr; IntegerAttr alignment_attr;
if (alignment.hasValue()) if (alignment.hasValue())
alignment_attr = alignment_attr =
IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue()); IntegerAttr::get(IntegerType::get(ctx, 64), alignment.getValue());
if (!dynamicBuffers) if (!dynamicBuffers)
if (auto cst = size.getDefiningOp<ConstantIndexOp>()) if (auto cst = size.getDefiningOp<ConstantIndexOp>())
return options.useAlloca return options.useAlloca
? std_alloca(MemRefType::get(width * cst.getValue(), ? std_alloca(MemRefType::get(width * cst.getValue(),
IntegerType::get(8, ctx)), IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr) ValueRange{}, alignment_attr)
.value .value
: std_alloc(MemRefType::get(width * cst.getValue(), : std_alloc(MemRefType::get(width * cst.getValue(),
IntegerType::get(8, ctx)), IntegerType::get(ctx, 8)),
ValueRange{}, alignment_attr) ValueRange{}, alignment_attr)
.value; .value;
Value mul = Value mul =
folded_std_muli(folder, folded_std_constant_index(folder, width), size); folded_std_muli(folder, folded_std_constant_index(folder, width), size);
return options.useAlloca return options.useAlloca
? std_alloca(MemRefType::get(-1, IntegerType::get(8, ctx)), mul, ? std_alloca(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr) alignment_attr)
.value .value
: std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul, : std_alloc(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
alignment_attr) alignment_attr)
.value; .value;
} }

View File

@ -18,7 +18,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
int64_t &qmax) { int64_t &qmax) {
// Hard-coded type mapping from TFLite. // Hard-coded type mapping from TFLite.
if (numBits <= 8) { if (numBits <= 8) {
storageType = IntegerType::get(8, ctx); storageType = IntegerType::get(ctx, 8);
if (isSigned) { if (isSigned) {
qmin = -128; qmin = -128;
qmax = 127; qmax = 127;
@ -27,7 +27,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
qmax = 255; qmax = 255;
} }
} else if (numBits <= 16) { } else if (numBits <= 16) {
storageType = IntegerType::get(16, ctx); storageType = IntegerType::get(ctx, 16);
if (isSigned) { if (isSigned) {
qmin = -32768; qmin = -32768;
qmax = 32767; qmax = 32767;
@ -36,7 +36,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
qmax = 65535; qmax = 65535;
} }
} else if (numBits <= 32) { } else if (numBits <= 32) {
storageType = IntegerType::get(32, ctx); storageType = IntegerType::get(ctx, 32);
if (isSigned) { if (isSigned) {
qmin = std::numeric_limits<int32_t>::min(); qmin = std::numeric_limits<int32_t>::min();
qmax = std::numeric_limits<int32_t>::max(); qmax = std::numeric_limits<int32_t>::max();

View File

@ -79,7 +79,7 @@ UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) {
int64_t chunkSize = int64_t chunkSize =
std::accumulate(std::next(shape.begin(), quantizationDim + 1), std::accumulate(std::next(shape.begin(), quantizationDim + 1),
shape.end(), 1, std::multiplies<int64_t>()); shape.end(), 1, std::multiplies<int64_t>());
Type newElementType = IntegerType::get(storageBitWidth, attr.getContext()); Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth);
return attr.mapValues(newElementType, [&](const APFloat &old) { return attr.mapValues(newElementType, [&](const APFloat &old) {
int chunkIndex = (flattenIndex++) / chunkSize; int chunkIndex = (flattenIndex++) / chunkSize;
return converters[chunkIndex % dimSize].quantizeFloatToInt(old); return converters[chunkIndex % dimSize].quantizeFloatToInt(old);

View File

@ -96,7 +96,7 @@ void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
ValueRange values(captures.getArrayRef()); ValueRange values(captures.getArrayRef());
FunctionType type = FunctionType type =
FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx); FunctionType::get(ctx, values.getTypes(), ifOp.getResultTypes());
auto outlinedFunc = b.create<FuncOp>(loc, funcName, type); auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
b.setInsertionPointToStart(outlinedFunc.addEntryBlock()); b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
BlockAndValueMapping bvm; BlockAndValueMapping bvm;

View File

@ -123,7 +123,7 @@ spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize, MLIRContext *context) {
assert(localSize.size() == 3); assert(localSize.size() == 3);
return spirv::EntryPointABIAttr::get( return spirv::EntryPointABIAttr::get(
DenseElementsAttr::get<int32_t>( DenseElementsAttr::get<int32_t>(
VectorType::get(3, IntegerType::get(32, context)), localSize) VectorType::get(3, IntegerType::get(context, 32)), localSize)
.cast<DenseIntElementsAttr>(), .cast<DenseIntElementsAttr>(),
context); context);
} }

View File

@ -93,7 +93,7 @@ Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
// instructions. The Vulkan spec requires the builtins like // instructions. The Vulkan spec requires the builtins like
// GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
// SExtended to 64-bit for index computations. // SExtended to 64-bit for index computations.
return IntegerType::get(32, context); return IntegerType::get(context, 32);
} }
/// Mapping between SPIR-V storage classes to memref memory spaces. /// Mapping between SPIR-V storage classes to memref memory spaces.
@ -260,8 +260,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
auto intType = type.cast<IntegerType>(); auto intType = type.cast<IntegerType>();
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
return IntegerType::get(/*width=*/32, intType.getSignedness(), return IntegerType::get(targetEnv.getContext(), /*width=*/32,
targetEnv.getContext()); intType.getSignedness());
} }
/// Converts a vector `type` to a suitable type under the given `targetEnv`. /// Converts a vector `type` to a suitable type under the given `targetEnv`.

View File

@ -714,7 +714,7 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
} }
FunctionType CallOp::getCalleeType() { FunctionType CallOp::getCalleeType() {
return FunctionType::get(getOperandTypes(), getResultTypes(), getContext()); return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -753,7 +753,7 @@ void CallIndirectOp::getCanonicalizationPatterns(
// Return the type of the same shape (scalar, vector or tensor) containing i1. // Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Type type) { static Type getI1SameShape(Type type) {
auto i1Type = IntegerType::get(1, type.getContext()); auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto tensorType = type.dyn_cast<RankedTensorType>()) if (auto tensorType = type.dyn_cast<RankedTensorType>())
return RankedTensorType::get(tensorType.getShape(), i1Type); return RankedTensorType::get(tensorType.getShape(), i1Type);
if (type.isa<UnrankedTensorType>()) if (type.isa<UnrankedTensorType>())
@ -914,7 +914,7 @@ OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
return {}; return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); return IntegerAttr::get(IntegerType::get(getContext(), 1), APInt(1, val));
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1426,7 +1426,7 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values, static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
MLIRContext *context) { MLIRContext *context) {
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v)); return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
}); });
return ArrayAttr::get(llvm::to_vector<8>(attrs), context); return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
} }
@ -2767,7 +2767,7 @@ static ParseResult parseTupleOp(OpAsmParser &parser, OperationState &result) {
parser.parseOptionalAttrDict(result.attributes) || parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(types) || parser.parseColonTypeList(types) ||
parser.resolveOperands(operandInfos, types, loc, result.operands) || parser.resolveOperands(operandInfos, types, loc, result.operands) ||
parser.addTypeToList(TupleType::get(types, ctx), result.types)); parser.addTypeToList(TupleType::get(ctx, types), result.types));
} }
static void print(OpAsmPrinter &p, TupleOp op) { static void print(OpAsmPrinter &p, TupleOp op) {

View File

@ -215,7 +215,7 @@ static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
// Create Vector type and add to 'vectorTypes[i]'. // Create Vector type and add to 'vectorTypes[i]'.
vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType()); vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
} }
return TupleType::get(vectorTypes, builder.getContext()); return builder.getTupleType(vectorTypes);
} }
// UnrolledVectorState aggregates per-operand/result vector state required for // UnrolledVectorState aggregates per-operand/result vector state required for

View File

@ -52,27 +52,27 @@ FloatType Builder::getF64Type() { return FloatType::getF64(context); }
IndexType Builder::getIndexType() { return IndexType::get(context); } IndexType Builder::getIndexType() { return IndexType::get(context); }
IntegerType Builder::getI1Type() { return IntegerType::get(1, context); } IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); }
IntegerType Builder::getI32Type() { return IntegerType::get(32, context); } IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); }
IntegerType Builder::getI64Type() { return IntegerType::get(64, context); } IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); }
IntegerType Builder::getIntegerType(unsigned width) { IntegerType Builder::getIntegerType(unsigned width) {
return IntegerType::get(width, context); return IntegerType::get(context, width);
} }
IntegerType Builder::getIntegerType(unsigned width, bool isSigned) { IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
return IntegerType::get( return IntegerType::get(
width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context); context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned);
} }
FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) { FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
return FunctionType::get(inputs, results, context); return FunctionType::get(context, inputs, results);
} }
TupleType Builder::getTupleType(TypeRange elementTypes) { TupleType Builder::getTupleType(TypeRange elementTypes) {
return TupleType::get(elementTypes, context); return TupleType::get(context, elementTypes);
} }
NoneType Builder::getNoneType() { return NoneType::get(context); } NoneType Builder::getNoneType() { return NoneType::get(context); }

View File

@ -179,7 +179,7 @@ FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
for (unsigned i = 0, e = getNumArguments(); i != e; ++i) for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
if (!mapper.contains(getArgument(i))) if (!mapper.contains(getArgument(i)))
inputTypes.push_back(newType.getInput(i)); inputTypes.push_back(newType.getInput(i));
newType = FunctionType::get(inputTypes, newType.getResults(), getContext()); newType = FunctionType::get(getContext(), inputTypes, newType.getResults());
} }
// Create the new function. // Create the new function.

View File

@ -35,7 +35,7 @@ ComplexType ComplexType::get(Type elementType) {
return Base::get(elementType.getContext(), elementType); return Base::get(elementType.getContext(), elementType);
} }
ComplexType ComplexType::getChecked(Type elementType, Location location) { ComplexType ComplexType::getChecked(Location location, Type elementType) {
return Base::getChecked(location, elementType); return Base::getChecked(location, elementType);
} }
@ -76,7 +76,7 @@ IntegerType::SignednessSemantics IntegerType::getSignedness() const {
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
if (!scale) if (!scale)
return IntegerType(); return IntegerType();
return IntegerType::get(scale * getWidth(), getSignedness(), getContext()); return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -126,8 +126,8 @@ FloatType FloatType::scaleElementBitwidth(unsigned scale) {
// FunctionType // FunctionType
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
FunctionType FunctionType::get(TypeRange inputs, TypeRange results, FunctionType FunctionType::get(MLIRContext *context, TypeRange inputs,
MLIRContext *context) { TypeRange results) {
return Base::get(context, inputs, results); return Base::get(context, inputs, results);
} }
@ -182,20 +182,20 @@ FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
newResultTypes = newResultTypesBuffer; newResultTypes = newResultTypesBuffer;
} }
return get(newInputTypes, newResultTypes, getContext()); return get(getContext(), newInputTypes, newResultTypes);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// OpaqueType // OpaqueType
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData, OpaqueType OpaqueType::get(MLIRContext *context, Identifier dialect,
MLIRContext *context) { StringRef typeData) {
return Base::get(context, dialect, typeData); return Base::get(context, dialect, typeData);
} }
OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData, OpaqueType OpaqueType::getChecked(Location location, Identifier dialect,
MLIRContext *context, Location location) { StringRef typeData) {
return Base::getChecked(location, dialect, typeData); return Base::getChecked(location, dialect, typeData);
} }
@ -313,8 +313,8 @@ VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
return Base::get(elementType.getContext(), shape, elementType); return Base::get(elementType.getContext(), shape, elementType);
} }
VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType, VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape,
Location location) { Type elementType) {
return Base::getChecked(location, shape, elementType); return Base::getChecked(location, shape, elementType);
} }
@ -379,9 +379,9 @@ RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
return Base::get(elementType.getContext(), shape, elementType); return Base::get(elementType.getContext(), shape, elementType);
} }
RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape, RankedTensorType RankedTensorType::getChecked(Location location,
Type elementType, ArrayRef<int64_t> shape,
Location location) { Type elementType) {
return Base::getChecked(location, shape, elementType); return Base::getChecked(location, shape, elementType);
} }
@ -406,8 +406,8 @@ UnrankedTensorType UnrankedTensorType::get(Type elementType) {
return Base::get(elementType.getContext(), elementType); return Base::get(elementType.getContext(), elementType);
} }
UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, UnrankedTensorType UnrankedTensorType::getChecked(Location location,
Location location) { Type elementType) {
return Base::getChecked(location, elementType); return Base::getChecked(location, elementType);
} }
@ -448,9 +448,10 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
/// UnknownLoc. If the MemRefType defined by the arguments would be /// UnknownLoc. If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to /// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr. /// the error stream) and returns nullptr.
MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType, MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape,
Type elementType,
ArrayRef<AffineMap> affineMapComposition, ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace, Location location) { unsigned memorySpace) {
return getImpl(shape, elementType, affineMapComposition, memorySpace, return getImpl(shape, elementType, affineMapComposition, memorySpace,
location); location);
} }
@ -524,9 +525,9 @@ UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
return Base::get(elementType.getContext(), elementType, memorySpace); return Base::get(elementType.getContext(), elementType, memorySpace);
} }
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType, UnrankedMemRefType UnrankedMemRefType::getChecked(Location location,
unsigned memorySpace, Type elementType,
Location location) { unsigned memorySpace) {
return Base::getChecked(location, elementType, memorySpace); return Base::getChecked(location, elementType, memorySpace);
} }
@ -694,12 +695,12 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
/// Get or create a new TupleType with the provided element types. Assumes the /// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type. /// arguments define a well-formed type.
TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) { TupleType TupleType::get(MLIRContext *context, TypeRange elementTypes) {
return Base::get(context, elementTypes); return Base::get(context, elementTypes);
} }
/// Get or create an empty tuple type. /// Get or create an empty tuple type.
TupleType TupleType::get(MLIRContext *context) { return get({}, context); } TupleType TupleType::get(MLIRContext *context) { return get(context, {}); }
/// Return the elements types for this tuple. /// Return the elements types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

View File

@ -82,7 +82,7 @@ Type Dialect::parseType(DialectAsmParser &parser) const {
// If this dialect allows unknown types, then represent this with OpaqueType. // If this dialect allows unknown types, then represent this with OpaqueType.
if (allowsUnknownTypes()) { if (allowsUnknownTypes()) {
auto ns = Identifier::get(getNamespace(), getContext()); auto ns = Identifier::get(getNamespace(), getContext());
return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext()); return OpaqueType::get(getContext(), ns, parser.getFullSymbolSpec());
} }
parser.emitError(parser.getNameLoc()) parser.emitError(parser.getNameLoc())

View File

@ -772,25 +772,23 @@ getCachedIntegerType(unsigned width,
} }
} }
IntegerType IntegerType::get(unsigned width, MLIRContext *context) { IntegerType IntegerType::get(MLIRContext *context, unsigned width) {
return get(width, IntegerType::Signless, context); return get(context, width, IntegerType::Signless);
} }
IntegerType IntegerType::get(unsigned width, IntegerType IntegerType::get(MLIRContext *context, unsigned width,
IntegerType::SignednessSemantics signedness, IntegerType::SignednessSemantics signedness) {
MLIRContext *context) {
if (auto cached = getCachedIntegerType(width, signedness, context)) if (auto cached = getCachedIntegerType(width, signedness, context))
return cached; return cached;
return Base::get(context, width, signedness); return Base::get(context, width, signedness);
} }
IntegerType IntegerType::getChecked(unsigned width, Location location) { IntegerType IntegerType::getChecked(Location location, unsigned width) {
return getChecked(width, IntegerType::Signless, location); return getChecked(location, width, IntegerType::Signless);
} }
IntegerType IntegerType::getChecked(unsigned width, IntegerType IntegerType::getChecked(Location location, unsigned width,
SignednessSemantics signedness, SignednessSemantics signedness) {
Location location) {
if (auto cached = if (auto cached =
getCachedIntegerType(width, signedness, location->getContext())) getCachedIntegerType(width, signedness, location->getContext()))
return cached; return cached;

View File

@ -178,7 +178,7 @@ Operation::Operation(Location location, OperationName name,
if (hasSingleResult) if (hasSingleResult)
resultType = resultTypes.front(); resultType = resultTypes.front();
else else
resultType = TupleType::get(resultTypes, location->getContext()); resultType = TupleType::get(location->getContext(), resultTypes);
} }
} }

View File

@ -63,7 +63,7 @@ void Value::setType(Type newType) {
return; return;
auto newTypes = llvm::to_vector<4>(curTypes); auto newTypes = llvm::to_vector<4>(curTypes);
newTypes[resultNo] = newType; newTypes[resultNo] = newType;
owner->resultType = TupleType::get(newTypes, newType.getContext()); owner->resultType = TupleType::get(newType.getContext(), newTypes);
} }
/// If this value is the result of an Operation, return the operation that /// If this value is the result of an Operation, return the operation that

View File

@ -563,8 +563,8 @@ Type Parser::parseExtendedType() {
// Otherwise, form a new opaque type. // Otherwise, form a new opaque type.
return OpaqueType::getChecked( return OpaqueType::getChecked(
Identifier::get(dialectName, state.context), symbolData, getEncodedSourceLocation(loc),
state.context, getEncodedSourceLocation(loc)); Identifier::get(dialectName, state.context), symbolData);
}); });
} }

View File

@ -338,7 +338,7 @@ Type Parser::parseNonFunctionType() {
signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned; signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
consumeToken(Token::inttype); consumeToken(Token::inttype);
return IntegerType::get(width.getValue(), signSemantics, getContext()); return IntegerType::get(getContext(), width.getValue(), signSemantics);
} }
// float-type // float-type
@ -432,7 +432,7 @@ Type Parser::parseTupleType() {
parseToken(Token::greater, "expected '>' in tuple type")) parseToken(Token::greater, "expected '>' in tuple type"))
return nullptr; return nullptr;
return TupleType::get(types, getContext()); return TupleType::get(getContext(), types);
} }
/// Parse a vector type. /// Parse a vector type.

View File

@ -236,7 +236,7 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
Attribute Importer::getConstantAsAttr(llvm::Constant *value) { Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
if (auto *ci = dyn_cast<llvm::ConstantInt>(value)) if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
return b.getIntegerAttr( return b.getIntegerAttr(
IntegerType::get(ci->getType()->getBitWidth(), context), IntegerType::get(context, ci->getType()->getBitWidth()),
ci->getValue()); ci->getValue());
if (auto *c = dyn_cast<llvm::ConstantDataArray>(value)) if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
if (c->isString()) if (c->isString())

View File

@ -1182,7 +1182,7 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
// signless semantics for such cases. // signless semantics for such cases.
auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
: IntegerType::SignednessSemantics::Signless; : IntegerType::SignednessSemantics::Signless;
typeMap[operands[0]] = IntegerType::get(operands[1], sign, context); typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
} break; } break;
case spirv::Opcode::OpTypeFloat: { case spirv::Opcode::OpTypeFloat: {
if (operands.size() != 2) if (operands.size() != 2)
@ -1345,7 +1345,7 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
if (!isVoidType(returnType)) { if (!isVoidType(returnType)) {
returnTypes = llvm::makeArrayRef(returnType); returnTypes = llvm::makeArrayRef(returnType);
} }
typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context); typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
return success(); return success();
} }

View File

@ -1267,7 +1267,7 @@ LogicalResult Serializer::prepareBasicType(
} }
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
auto getConstantOp = [&](uint32_t id) { auto getConstantOp = [&](uint32_t id) {
auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id); auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
return prepareConstantInt(loc, attr); return prepareConstantInt(loc, attr);
}; };
operands.push_back(elementTypeID); operands.push_back(elementTypeID);

View File

@ -35,8 +35,8 @@ static void updateFuncOp(FuncOp func,
// Add the new arguments to the function type. // Add the new arguments to the function type.
auto newArgTypes = llvm::to_vector<6>( auto newArgTypes = llvm::to_vector<6>(
llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes)); llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
auto newFunctionType = FunctionType::get( auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
newArgTypes, functionType.getResults(), func.getContext()); functionType.getResults());
func.setType(newFunctionType); func.setType(newFunctionType);
// Transfer the result attributes to arg attributes. // Transfer the result attributes to arg attributes.

View File

@ -230,9 +230,8 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
// We create a new function type and modify the function signature with this // We create a new function type and modify the function signature with this
// new type. // new type.
newFuncType = FunctionType::get(/*inputs=*/argTypes, newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
/*results=*/resultTypes, /*results=*/resultTypes);
/*context=*/&getContext());
} }
// Since we update the function signature, it might affect the result types at // Since we update the function signature, it might affect the result types at
@ -463,9 +462,9 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
continue; continue;
} }
FunctionType newFuncType = FunctionType::get(/*inputs=*/inputTypes, FunctionType newFuncType =
/*results=*/resultTypes, FunctionType::get(&getContext(), /*inputs=*/inputTypes,
/*context=*/&getContext()); /*results=*/resultTypes);
// Setting the new function signature for this external function. // Setting the new function signature for this external function.
funcOp.setType(newFuncType); funcOp.setType(newFuncType);
} }

View File

@ -2522,8 +2522,8 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
// Update the function signature in-place. // Update the function signature in-place.
rewriter.updateRootInPlace(funcOp, [&] { rewriter.updateRootInPlace(funcOp, [&] {
funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults, funcOp.setType(FunctionType::get(funcOp.getContext(),
funcOp.getContext())); result.getConvertedTypes(), newResults));
}); });
return success(); return success();
} }

View File

@ -56,7 +56,7 @@ static FuncOp makeFunction(StringRef name, ArrayRef<Type> results = {},
ArrayRef<Type> args = {}) { ArrayRef<Type> args = {}) {
auto &ctx = globalContext(); auto &ctx = globalContext();
auto function = FuncOp::create(UnknownLoc::get(&ctx), name, auto function = FuncOp::create(UnknownLoc::get(&ctx), name,
FunctionType::get(args, results, &ctx)); FunctionType::get(&ctx, args, results));
function.addEntryBlock(); function.addEntryBlock();
return function; return function;
} }
@ -277,7 +277,7 @@ TEST_FUNC(builder_blocks) {
TEST_FUNC(builder_cond_branch) { TEST_FUNC(builder_cond_branch) {
auto f = makeFunction("builder_cond_branch", {}, auto f = makeFunction("builder_cond_branch", {},
{IntegerType::get(1, &globalContext())}); {IntegerType::get(&globalContext(), 1)});
OpBuilder builder(f.getBody()); OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc()); ScopedContext scope(builder, f.getLoc());
@ -390,8 +390,8 @@ TEST_FUNC(insertion_in_block) {
TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) { TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) {
using namespace edsc::op; using namespace edsc::op;
auto i1Type = IntegerType::get(1, &globalContext()); auto i1Type = IntegerType::get(&globalContext(), 1);
auto i8Type = IntegerType::get(8, &globalContext()); auto i8Type = IntegerType::get(&globalContext(), 8);
auto memrefType = MemRefType::get({}, i1Type, {}, 0); auto memrefType = MemRefType::get({}, i1Type, {}, 0);
auto f = makeFunction("zero_and_std_sign_extendi_op", {}, auto f = makeFunction("zero_and_std_sign_extendi_op", {},
{memrefType, memrefType}); {memrefType, memrefType});
@ -414,7 +414,7 @@ TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) {
} }
TEST_FUNC(operator_or) { TEST_FUNC(operator_or) {
auto i1Type = IntegerType::get(/*width=*/1, &globalContext()); auto i1Type = IntegerType::get(&globalContext(), /*width=*/1);
auto f = makeFunction("operator_or", {}, {i1Type, i1Type}); auto f = makeFunction("operator_or", {}, {i1Type, i1Type});
OpBuilder builder(f.getBody()); OpBuilder builder(f.getBody());
@ -435,7 +435,7 @@ TEST_FUNC(operator_or) {
} }
TEST_FUNC(operator_and) { TEST_FUNC(operator_and) {
auto i1Type = IntegerType::get(/*width=*/1, &globalContext()); auto i1Type = IntegerType::get(&globalContext(), /*width=*/1);
auto f = makeFunction("operator_and", {}, {i1Type, i1Type}); auto f = makeFunction("operator_and", {}, {i1Type, i1Type});
OpBuilder builder(f.getBody()); OpBuilder builder(f.getBody());
@ -536,7 +536,7 @@ TEST_FUNC(fptrunc_f32_bf16) {
TEST_FUNC(select_op_i32) { TEST_FUNC(select_op_i32) {
using namespace edsc::op; using namespace edsc::op;
auto i32Type = IntegerType::get(32, &globalContext()); auto i32Type = IntegerType::get(&globalContext(), 32);
auto memrefType = MemRefType::get( auto memrefType = MemRefType::get(
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, i32Type, {}, 0); {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, i32Type, {}, 0);
auto f = makeFunction("select_op", {}, {memrefType}); auto f = makeFunction("select_op", {}, {memrefType});

View File

@ -653,7 +653,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
} }
int64_t dim = int64_t dim =
sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
auto type = IntegerType::get(17, context); auto type = IntegerType::get(context, 17);
inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
return success(); return success();
} }

View File

@ -509,7 +509,7 @@ struct TestTypeConverter : public TypeConverter {
// Convert I42 to I43. // Convert I42 to I43.
if (t.isInteger(42)) { if (t.isInteger(42)) {
results.push_back(IntegerType::get(43, t.getContext())); results.push_back(IntegerType::get(t.getContext(), 43));
return success(); return success();
} }

View File

@ -69,9 +69,7 @@ struct TestDecomposeCallGraphTypes
Location loc) -> Optional<Value> { Location loc) -> Optional<Value> {
if (inputs.size() == 1) if (inputs.size() == 1)
return llvm::None; return llvm::None;
TypeRange TypeRange = inputs.getTypes(); TupleType tuple = builder.getTupleType(inputs.getTypes());
SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
TupleType tuple = TupleType::get(types, builder.getContext());
Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs); Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
return value; return value;
}); });

View File

@ -59,7 +59,7 @@ ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
} else { } else {
tensorType = RankedTensorType::get(shape, eleType); tensorType = RankedTensorType::get(shape, eleType);
} }
auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx)); auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(ctx, 64));
auto indices = auto indices =
DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
auto valuesType = RankedTensorType::get({1}, eleType); auto valuesType = RankedTensorType::get({1}, eleType);
@ -77,7 +77,7 @@ UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
TEST(QuantizationUtilsTest, convertFloatAttrUniform) { TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
MLIRContext ctx; MLIRContext ctx;
ctx.getOrLoadDialect<QuantizationDialect>(); ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx); IntegerType convertedType = IntegerType::get(&ctx, 8);
auto quantizedType = getTestQuantizedType(convertedType, &ctx); auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType); TestUniformQuantizedValueConverter converter(quantizedType);
@ -95,7 +95,7 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) { TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
MLIRContext ctx; MLIRContext ctx;
ctx.getOrLoadDialect<QuantizationDialect>(); ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx); IntegerType convertedType = IntegerType::get(&ctx, 8);
auto quantizedType = getTestQuantizedType(convertedType, &ctx); auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType); TestUniformQuantizedValueConverter converter(quantizedType);
auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>( auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
@ -120,7 +120,7 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
MLIRContext ctx; MLIRContext ctx;
ctx.getOrLoadDialect<QuantizationDialect>(); ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx); IntegerType convertedType = IntegerType::get(&ctx, 8);
auto quantizedType = getTestQuantizedType(convertedType, &ctx); auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType); TestUniformQuantizedValueConverter converter(quantizedType);
auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>( auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
@ -145,7 +145,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) { TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
MLIRContext ctx; MLIRContext ctx;
ctx.getOrLoadDialect<QuantizationDialect>(); ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx); IntegerType convertedType = IntegerType::get(&ctx, 8);
auto quantizedType = getTestQuantizedType(convertedType, &ctx); auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType); TestUniformQuantizedValueConverter converter(quantizedType);
auto realValue = getTestSparseElementsAttr(&ctx, {1, 2}); auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});

View File

@ -33,7 +33,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
namespace { namespace {
TEST(DenseSplatTest, BoolSplat) { TEST(DenseSplatTest, BoolSplat) {
MLIRContext context; MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context); IntegerType boolTy = IntegerType::get(&context, 1);
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
// Check that splat is automatically detected for boolean values. // Check that splat is automatically detected for boolean values.
@ -58,7 +58,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
constexpr int64_t boolCount = 56; constexpr int64_t boolCount = 56;
MLIRContext context; MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context); IntegerType boolTy = IntegerType::get(&context, 1);
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
// Check that splat is automatically detected for boolean values. // Check that splat is automatically detected for boolean values.
@ -81,7 +81,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
TEST(DenseSplatTest, BoolNonSplat) { TEST(DenseSplatTest, BoolNonSplat) {
MLIRContext context; MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context); IntegerType boolTy = IntegerType::get(&context, 1);
RankedTensorType shape = RankedTensorType::get({6}, boolTy); RankedTensorType shape = RankedTensorType::get({6}, boolTy);
// Check that we properly handle non-splat values. // Check that we properly handle non-splat values.
@ -94,7 +94,7 @@ TEST(DenseSplatTest, OddIntSplat) {
// Test detecting a splat with an odd(non 8-bit) integer bitwidth. // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
MLIRContext context; MLIRContext context;
constexpr size_t intWidth = 19; constexpr size_t intWidth = 19;
IntegerType intTy = IntegerType::get(intWidth, &context); IntegerType intTy = IntegerType::get(&context, intWidth);
APInt value(intWidth, 10); APInt value(intWidth, 10);
testSplat(intTy, value); testSplat(intTy, value);
@ -102,7 +102,7 @@ TEST(DenseSplatTest, OddIntSplat) {
TEST(DenseSplatTest, Int32Splat) { TEST(DenseSplatTest, Int32Splat) {
MLIRContext context; MLIRContext context;
IntegerType intTy = IntegerType::get(32, &context); IntegerType intTy = IntegerType::get(&context, 32);
int value = 64; int value = 64;
testSplat(intTy, value); testSplat(intTy, value);
@ -110,7 +110,7 @@ TEST(DenseSplatTest, Int32Splat) {
TEST(DenseSplatTest, IntAttrSplat) { TEST(DenseSplatTest, IntAttrSplat) {
MLIRContext context; MLIRContext context;
IntegerType intTy = IntegerType::get(85, &context); IntegerType intTy = IntegerType::get(&context, 85);
Attribute value = IntegerAttr::get(intTy, 109); Attribute value = IntegerAttr::get(intTy, 109);
testSplat(intTy, value); testSplat(intTy, value);
@ -151,7 +151,7 @@ TEST(DenseSplatTest, BF16Splat) {
TEST(DenseSplatTest, StringSplat) { TEST(DenseSplatTest, StringSplat) {
MLIRContext context; MLIRContext context;
Type stringType = Type stringType =
OpaqueType::get(Identifier::get("test", &context), "string", &context); OpaqueType::get(&context, Identifier::get("test", &context), "string");
StringRef value = "test-string"; StringRef value = "test-string";
testSplat(stringType, value); testSplat(stringType, value);
} }
@ -159,7 +159,7 @@ TEST(DenseSplatTest, StringSplat) {
TEST(DenseSplatTest, StringAttrSplat) { TEST(DenseSplatTest, StringAttrSplat) {
MLIRContext context; MLIRContext context;
Type stringType = Type stringType =
OpaqueType::get(Identifier::get("test", &context), "string", &context); OpaqueType::get(&context, Identifier::get("test", &context), "string");
Attribute stringAttr = StringAttr::get("test-string", stringType); Attribute stringAttr = StringAttr::get("test-string", stringType);
testSplat(stringType, stringAttr); testSplat(stringType, stringAttr);
} }
@ -173,7 +173,7 @@ TEST(DenseComplexTest, ComplexFloatSplat) {
TEST(DenseComplexTest, ComplexIntSplat) { TEST(DenseComplexTest, ComplexIntSplat) {
MLIRContext context; MLIRContext context;
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context)); ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
std::complex<int64_t> value(10, 15); std::complex<int64_t> value(10, 15);
testSplat(complexType, value); testSplat(complexType, value);
} }
@ -187,7 +187,7 @@ TEST(DenseComplexTest, ComplexAPFloatSplat) {
TEST(DenseComplexTest, ComplexAPIntSplat) { TEST(DenseComplexTest, ComplexAPIntSplat) {
MLIRContext context; MLIRContext context;
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context)); ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
testSplat(complexType, value); testSplat(complexType, value);
} }

View File

@ -25,7 +25,7 @@ namespace mlir {
/// Helper that returns an example test::TestStruct for testing its /// Helper that returns an example test::TestStruct for testing its
/// implementation. /// implementation.
static test::TestStruct getTestStruct(mlir::MLIRContext *context) { static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
auto integerType = mlir::IntegerType::get(32, context); auto integerType = mlir::IntegerType::get(context, 32);
auto integerAttr = mlir::IntegerAttr::get(integerType, 127); auto integerAttr = mlir::IntegerAttr::get(integerType, 127);
auto floatType = mlir::FloatType::getF32(context); auto floatType = mlir::FloatType::getF32(context);
@ -105,7 +105,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) {
expectedValues.begin(), expectedValues.end() - 1); expectedValues.begin(), expectedValues.end() - 1);
// Add a copy of the last attribute with the wrong type. // Add a copy of the last attribute with the wrong type.
auto i64Type = mlir::IntegerType::get(64, &context); auto i64Type = mlir::IntegerType::get(&context, 64);
auto elementsType = mlir::RankedTensorType::get({3}, i64Type); auto elementsType = mlir::RankedTensorType::get({3}, i64Type);
auto elementsAttr = auto elementsAttr =
mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3}); mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3});