[mlir][IR][NFC] Move context/location parameters of builtin Type::get methods to the start of the parameter list
This better matches the rest of the infrastructure, is much simpler, and makes it easier to move these types to being declaratively specified. Differential Revision: https://reviews.llvm.org/D93432
This commit is contained in:
parent
511cfe9441
commit
1b97cdf885
|
@ -2176,7 +2176,7 @@ def fir_DispatchOp : fir_Op<"dispatch",
|
||||||
p.printOptionalAttrDict(getAttrs(), {"fn_type", "method"});
|
p.printOptionalAttrDict(getAttrs(), {"fn_type", "method"});
|
||||||
auto resTy{getResultTypes()};
|
auto resTy{getResultTypes()};
|
||||||
llvm::SmallVector<mlir::Type, 8> argTy(getOperandTypes());
|
llvm::SmallVector<mlir::Type, 8> argTy(getOperandTypes());
|
||||||
p << " : " << mlir::FunctionType::get(argTy, resTy, getContext());
|
p << " : " << mlir::FunctionType::get(getContext(), argTy, resTy);
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
|
|
|
@ -49,7 +49,7 @@ mlir::Type genFIRType(mlir::MLIRContext *context) {
|
||||||
if constexpr (TC == Fortran::common::TypeCategory::Integer) {
|
if constexpr (TC == Fortran::common::TypeCategory::Integer) {
|
||||||
auto bits{Fortran::evaluate::Type<Fortran::common::TypeCategory::Integer,
|
auto bits{Fortran::evaluate::Type<Fortran::common::TypeCategory::Integer,
|
||||||
KIND>::Scalar::bits};
|
KIND>::Scalar::bits};
|
||||||
return mlir::IntegerType::get(bits, context);
|
return mlir::IntegerType::get(context, bits);
|
||||||
} else if constexpr (TC == Fortran::common::TypeCategory::Logical ||
|
} else if constexpr (TC == Fortran::common::TypeCategory::Logical ||
|
||||||
TC == Fortran::common::TypeCategory::Character ||
|
TC == Fortran::common::TypeCategory::Character ||
|
||||||
TC == Fortran::common::TypeCategory::Complex) {
|
TC == Fortran::common::TypeCategory::Complex) {
|
||||||
|
@ -278,7 +278,7 @@ private:
|
||||||
|
|
||||||
// some sequence of `n` bytes
|
// some sequence of `n` bytes
|
||||||
mlir::Type gen(const Fortran::evaluate::StaticDataObject::Pointer &ptr) {
|
mlir::Type gen(const Fortran::evaluate::StaticDataObject::Pointer &ptr) {
|
||||||
mlir::Type byteTy{mlir::IntegerType::get(8, context)};
|
mlir::Type byteTy{mlir::IntegerType::get(context, 8)};
|
||||||
return fir::SequenceType::get(trivialShape(ptr->itemBytes()), byteTy);
|
return fir::SequenceType::get(trivialShape(ptr->itemBytes()), byteTy);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -298,26 +298,26 @@ static constexpr RuntimeFunction pgmathPrecise[] = {
|
||||||
|
|
||||||
static mlir::FunctionType genF32F32FuncType(mlir::MLIRContext *context) {
|
static mlir::FunctionType genF32F32FuncType(mlir::MLIRContext *context) {
|
||||||
auto t = mlir::FloatType::getF32(context);
|
auto t = mlir::FloatType::getF32(context);
|
||||||
return mlir::FunctionType::get({t}, {t}, context);
|
return mlir::FunctionType::get(context, {t}, {t});
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::FunctionType genF64F64FuncType(mlir::MLIRContext *context) {
|
static mlir::FunctionType genF64F64FuncType(mlir::MLIRContext *context) {
|
||||||
auto t = mlir::FloatType::getF64(context);
|
auto t = mlir::FloatType::getF64(context);
|
||||||
return mlir::FunctionType::get({t}, {t}, context);
|
return mlir::FunctionType::get(context, {t}, {t});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int Bits>
|
template <int Bits>
|
||||||
static mlir::FunctionType genIntF64FuncType(mlir::MLIRContext *context) {
|
static mlir::FunctionType genIntF64FuncType(mlir::MLIRContext *context) {
|
||||||
auto t = mlir::FloatType::getF64(context);
|
auto t = mlir::FloatType::getF64(context);
|
||||||
auto r = mlir::IntegerType::get(Bits, context);
|
auto r = mlir::IntegerType::get(context, Bits);
|
||||||
return mlir::FunctionType::get({t}, {r}, context);
|
return mlir::FunctionType::get(context, {t}, {r});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int Bits>
|
template <int Bits>
|
||||||
static mlir::FunctionType genIntF32FuncType(mlir::MLIRContext *context) {
|
static mlir::FunctionType genIntF32FuncType(mlir::MLIRContext *context) {
|
||||||
auto t = mlir::FloatType::getF32(context);
|
auto t = mlir::FloatType::getF32(context);
|
||||||
auto r = mlir::IntegerType::get(Bits, context);
|
auto r = mlir::IntegerType::get(context, Bits);
|
||||||
return mlir::FunctionType::get({t}, {r}, context);
|
return mlir::FunctionType::get(context, {t}, {r});
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO : Fill-up this table with more intrinsic.
|
// TODO : Fill-up this table with more intrinsic.
|
||||||
|
@ -585,8 +585,8 @@ getFunctionType(mlir::Type resultType, llvm::ArrayRef<mlir::Value> arguments,
|
||||||
llvm::SmallVector<mlir::Type, 2> argumentTypes;
|
llvm::SmallVector<mlir::Type, 2> argumentTypes;
|
||||||
for (auto &arg : arguments)
|
for (auto &arg : arguments)
|
||||||
argumentTypes.push_back(arg.getType());
|
argumentTypes.push_back(arg.getType());
|
||||||
return mlir::FunctionType::get(argumentTypes, resultType,
|
return mlir::FunctionType::get(builder.getModule().getContext(),
|
||||||
builder.getModule().getContext());
|
argumentTypes, resultType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// fir::ExtendedValue to mlir::Value translation layer
|
/// fir::ExtendedValue to mlir::Value translation layer
|
||||||
|
@ -1144,7 +1144,7 @@ mlir::Value IntrinsicLibrary::genMerge(mlir::Type,
|
||||||
llvm::ArrayRef<mlir::Value> args) {
|
llvm::ArrayRef<mlir::Value> args) {
|
||||||
assert(args.size() == 3);
|
assert(args.size() == 3);
|
||||||
|
|
||||||
auto i1Type = mlir::IntegerType::get(1, builder.getContext());
|
auto i1Type = mlir::IntegerType::get(builder.getContext(), 1);
|
||||||
auto mask = builder.createConvert(loc, i1Type, args[2]);
|
auto mask = builder.createConvert(loc, i1Type, args[2]);
|
||||||
return builder.create<mlir::SelectOp>(loc, mask, args[0], args[1]);
|
return builder.create<mlir::SelectOp>(loc, mask, args[0], args[1]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ static constexpr TypeBuilderFunc getModel();
|
||||||
template <>
|
template <>
|
||||||
constexpr TypeBuilderFunc getModel<int>() {
|
constexpr TypeBuilderFunc getModel<int>() {
|
||||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::IntegerType::get(8 * sizeof(int), context);
|
return mlir::IntegerType::get(context, 8 * sizeof(int));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
|
@ -61,14 +61,14 @@ constexpr TypeBuilderFunc getModel<int &>() {
|
||||||
template <>
|
template <>
|
||||||
constexpr TypeBuilderFunc getModel<Fortran::runtime::io::Iostat>() {
|
constexpr TypeBuilderFunc getModel<Fortran::runtime::io::Iostat>() {
|
||||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::IntegerType::get(8 * sizeof(Fortran::runtime::io::Iostat),
|
return mlir::IntegerType::get(context,
|
||||||
context);
|
8 * sizeof(Fortran::runtime::io::Iostat));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
constexpr TypeBuilderFunc getModel<char *>() {
|
constexpr TypeBuilderFunc getModel<char *>() {
|
||||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return fir::ReferenceType::get(mlir::IntegerType::get(8, context));
|
return fir::ReferenceType::get(mlir::IntegerType::get(context, 8));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
|
@ -78,26 +78,26 @@ constexpr TypeBuilderFunc getModel<const char *>() {
|
||||||
template <>
|
template <>
|
||||||
constexpr TypeBuilderFunc getModel<const char16_t *>() {
|
constexpr TypeBuilderFunc getModel<const char16_t *>() {
|
||||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return fir::ReferenceType::get(mlir::IntegerType::get(16, context));
|
return fir::ReferenceType::get(mlir::IntegerType::get(context, 16));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
constexpr TypeBuilderFunc getModel<const char32_t *>() {
|
constexpr TypeBuilderFunc getModel<const char32_t *>() {
|
||||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return fir::ReferenceType::get(mlir::IntegerType::get(32, context));
|
return fir::ReferenceType::get(mlir::IntegerType::get(context, 32));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
constexpr TypeBuilderFunc getModel<void **>() {
|
constexpr TypeBuilderFunc getModel<void **>() {
|
||||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return fir::ReferenceType::get(
|
return fir::ReferenceType::get(
|
||||||
fir::PointerType::get(mlir::IntegerType::get(8, context)));
|
fir::PointerType::get(mlir::IntegerType::get(context, 8)));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
constexpr TypeBuilderFunc getModel<std::int64_t>() {
|
constexpr TypeBuilderFunc getModel<std::int64_t>() {
|
||||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::IntegerType::get(64, context);
|
return mlir::IntegerType::get(context, 64);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
|
@ -110,7 +110,7 @@ constexpr TypeBuilderFunc getModel<std::int64_t &>() {
|
||||||
template <>
|
template <>
|
||||||
constexpr TypeBuilderFunc getModel<std::size_t>() {
|
constexpr TypeBuilderFunc getModel<std::size_t>() {
|
||||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::IntegerType::get(8 * sizeof(std::size_t), context);
|
return mlir::IntegerType::get(context, 8 * sizeof(std::size_t));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
|
@ -146,7 +146,7 @@ constexpr TypeBuilderFunc getModel<float &>() {
|
||||||
template <>
|
template <>
|
||||||
constexpr TypeBuilderFunc getModel<bool>() {
|
constexpr TypeBuilderFunc getModel<bool>() {
|
||||||
return [](mlir::MLIRContext *context) -> mlir::Type {
|
return [](mlir::MLIRContext *context) -> mlir::Type {
|
||||||
return mlir::IntegerType::get(1, context);
|
return mlir::IntegerType::get(context, 1);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
template <>
|
template <>
|
||||||
|
@ -190,7 +190,7 @@ struct RuntimeTableKey<RT(ATs...)> {
|
||||||
llvm::SmallVector<mlir::Type, sizeof...(ATs)> argTys;
|
llvm::SmallVector<mlir::Type, sizeof...(ATs)> argTys;
|
||||||
for (auto f : args)
|
for (auto f : args)
|
||||||
argTys.push_back(f(ctxt));
|
argTys.push_back(f(ctxt));
|
||||||
return mlir::FunctionType::get(argTys, {retTy}, ctxt);
|
return mlir::FunctionType::get(ctxt, argTys, {retTy});
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -151,7 +151,7 @@ mlir::Type fir::BoxDimsOp::getTupleType() {
|
||||||
// note: triple, but 4 is nearest power of 2
|
// note: triple, but 4 is nearest power of 2
|
||||||
llvm::SmallVector<mlir::Type, 4> triple{
|
llvm::SmallVector<mlir::Type, 4> triple{
|
||||||
getResult(0).getType(), getResult(1).getType(), getResult(2).getType()};
|
getResult(0).getType(), getResult(1).getType(), getResult(2).getType()};
|
||||||
return mlir::TupleType::get(triple, getContext());
|
return mlir::TupleType::get(getContext(), triple);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -171,7 +171,7 @@ static void printCallOp(mlir::OpAsmPrinter &p, fir::CallOp &op) {
|
||||||
auto resultTypes{op.getResultTypes()};
|
auto resultTypes{op.getResultTypes()};
|
||||||
llvm::SmallVector<Type, 8> argTypes(
|
llvm::SmallVector<Type, 8> argTypes(
|
||||||
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
|
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
|
||||||
p << " : " << FunctionType::get(argTypes, resultTypes, op.getContext());
|
p << " : " << FunctionType::get(op.getContext(), argTypes, resultTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser,
|
static mlir::ParseResult parseCallOp(mlir::OpAsmParser &parser,
|
||||||
|
@ -1565,4 +1565,3 @@ fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module,
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "flang/Optimizer/Dialect/FIROps.cpp.inc"
|
#include "flang/Optimizer/Dialect/FIROps.cpp.inc"
|
||||||
|
|
||||||
|
|
|
@ -35,8 +35,8 @@ def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect,
|
||||||
AllTypesMatch<["src", "a", "dst"]>,
|
AllTypesMatch<["src", "a", "dst"]>,
|
||||||
TypesMatchWith<"imm has the same number of bits as elements in dst",
|
TypesMatchWith<"imm has the same number of bits as elements in dst",
|
||||||
"dst", "imm",
|
"dst", "imm",
|
||||||
"IntegerType::get(($_self.cast<VectorType>().getShape()[0]),"
|
"IntegerType::get($_self.getContext(), "
|
||||||
" $_self.getContext())">]> {
|
"($_self.cast<VectorType>().getShape()[0]))">]> {
|
||||||
let summary = "Masked roundscale op";
|
let summary = "Masked roundscale op";
|
||||||
let description = [{
|
let description = [{
|
||||||
The mask.rndscale op is an AVX512 specific op that can lower to the proper
|
The mask.rndscale op is an AVX512 specific op that can lower to the proper
|
||||||
|
@ -67,8 +67,8 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
|
||||||
AllTypesMatch<["src", "a", "b", "dst"]>,
|
AllTypesMatch<["src", "a", "b", "dst"]>,
|
||||||
TypesMatchWith<"k has the same number of bits as elements in dst",
|
TypesMatchWith<"k has the same number of bits as elements in dst",
|
||||||
"dst", "k",
|
"dst", "k",
|
||||||
"IntegerType::get(($_self.cast<VectorType>().getShape()[0]),"
|
"IntegerType::get($_self.getContext(), "
|
||||||
" $_self.getContext())">]> {
|
"($_self.cast<VectorType>().getShape()[0]))">]> {
|
||||||
let summary = "ScaleF op";
|
let summary = "ScaleF op";
|
||||||
let description = [{
|
let description = [{
|
||||||
The `mask.scalef` op is an AVX512 specific op that can lower to the proper
|
The `mask.scalef` op is an AVX512 specific op that can lower to the proper
|
||||||
|
|
|
@ -911,7 +911,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
|
||||||
auto attr = (*this)->getAttr("operand_segment_sizes")
|
auto attr = (*this)->getAttr("operand_segment_sizes")
|
||||||
.cast<DenseIntElementsAttr>();
|
.cast<DenseIntElementsAttr>();
|
||||||
unsigned i = 0;
|
unsigned i = 0;
|
||||||
auto newAttr = attr.mapValues(IntegerType::get(32, getContext()),
|
auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32),
|
||||||
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
|
[&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; });
|
||||||
getOperation()->setAttr("operand_segment_sizes", newAttr);
|
getOperation()->setAttr("operand_segment_sizes", newAttr);
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,7 @@ public:
|
||||||
/// Get or create a ComplexType with the provided element type. This emits
|
/// Get or create a ComplexType with the provided element type. This emits
|
||||||
/// and error at the specified location and returns null if the element type
|
/// and error at the specified location and returns null if the element type
|
||||||
/// isn't supported.
|
/// isn't supported.
|
||||||
static ComplexType getChecked(Type elementType, Location location);
|
static ComplexType getChecked(Location location, Type elementType);
|
||||||
|
|
||||||
/// Verify the construction of an integer type.
|
/// Verify the construction of an integer type.
|
||||||
static LogicalResult verifyConstructionInvariants(Location loc,
|
static LogicalResult verifyConstructionInvariants(Location loc,
|
||||||
|
@ -93,27 +93,27 @@ public:
|
||||||
/// The created IntegerType is signless (i.e., no signedness semantics).
|
/// The created IntegerType is signless (i.e., no signedness semantics).
|
||||||
/// Assume the width is within the allowed range and assert on failures. Use
|
/// Assume the width is within the allowed range and assert on failures. Use
|
||||||
/// getChecked to handle failures gracefully.
|
/// getChecked to handle failures gracefully.
|
||||||
static IntegerType get(unsigned width, MLIRContext *context);
|
static IntegerType get(MLIRContext *context, unsigned width);
|
||||||
|
|
||||||
/// Get or create a new IntegerType of the given width within the context.
|
/// Get or create a new IntegerType of the given width within the context.
|
||||||
/// The created IntegerType has signedness semantics as indicated via
|
/// The created IntegerType has signedness semantics as indicated via
|
||||||
/// `signedness`. Assume the width is within the allowed range and assert on
|
/// `signedness`. Assume the width is within the allowed range and assert on
|
||||||
/// failures. Use getChecked to handle failures gracefully.
|
/// failures. Use getChecked to handle failures gracefully.
|
||||||
static IntegerType get(unsigned width, SignednessSemantics signedness,
|
static IntegerType get(MLIRContext *context, unsigned width,
|
||||||
MLIRContext *context);
|
SignednessSemantics signedness);
|
||||||
|
|
||||||
/// Get or create a new IntegerType of the given width within the context,
|
/// Get or create a new IntegerType of the given width within the context,
|
||||||
/// defined at the given, potentially unknown, location. The created
|
/// defined at the given, potentially unknown, location. The created
|
||||||
/// IntegerType is signless (i.e., no signedness semantics). If the width is
|
/// IntegerType is signless (i.e., no signedness semantics). If the width is
|
||||||
/// outside the allowed range, emit errors and return a null type.
|
/// outside the allowed range, emit errors and return a null type.
|
||||||
static IntegerType getChecked(unsigned width, Location location);
|
static IntegerType getChecked(Location location, unsigned width);
|
||||||
|
|
||||||
/// Get or create a new IntegerType of the given width within the context,
|
/// Get or create a new IntegerType of the given width within the context,
|
||||||
/// defined at the given, potentially unknown, location. The created
|
/// defined at the given, potentially unknown, location. The created
|
||||||
/// IntegerType has signedness semantics as indicated via `signedness`. If the
|
/// IntegerType has signedness semantics as indicated via `signedness`. If the
|
||||||
/// width is outside the allowed range, emit errors and return a null type.
|
/// width is outside the allowed range, emit errors and return a null type.
|
||||||
static IntegerType getChecked(unsigned width, SignednessSemantics signedness,
|
static IntegerType getChecked(Location location, unsigned width,
|
||||||
Location location);
|
SignednessSemantics signedness);
|
||||||
|
|
||||||
/// Verify the construction of an integer type.
|
/// Verify the construction of an integer type.
|
||||||
static LogicalResult
|
static LogicalResult
|
||||||
|
@ -180,8 +180,8 @@ class FunctionType
|
||||||
public:
|
public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
static FunctionType get(TypeRange inputs, TypeRange results,
|
static FunctionType get(MLIRContext *context, TypeRange inputs,
|
||||||
MLIRContext *context);
|
TypeRange results);
|
||||||
|
|
||||||
/// Input types.
|
/// Input types.
|
||||||
unsigned getNumInputs() const;
|
unsigned getNumInputs() const;
|
||||||
|
@ -211,14 +211,14 @@ public:
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
/// Get or create a new OpaqueType with the provided dialect and string data.
|
/// Get or create a new OpaqueType with the provided dialect and string data.
|
||||||
static OpaqueType get(Identifier dialect, StringRef typeData,
|
static OpaqueType get(MLIRContext *context, Identifier dialect,
|
||||||
MLIRContext *context);
|
StringRef typeData);
|
||||||
|
|
||||||
/// Get or create a new OpaqueType with the provided dialect and string data.
|
/// Get or create a new OpaqueType with the provided dialect and string data.
|
||||||
/// If the given identifier is not a valid namespace for a dialect, then a
|
/// If the given identifier is not a valid namespace for a dialect, then a
|
||||||
/// null type is returned.
|
/// null type is returned.
|
||||||
static OpaqueType getChecked(Identifier dialect, StringRef typeData,
|
static OpaqueType getChecked(Location location, Identifier dialect,
|
||||||
MLIRContext *context, Location location);
|
StringRef typeData);
|
||||||
|
|
||||||
/// Returns the dialect namespace of the opaque type.
|
/// Returns the dialect namespace of the opaque type.
|
||||||
Identifier getDialectNamespace() const;
|
Identifier getDialectNamespace() const;
|
||||||
|
@ -335,8 +335,8 @@ public:
|
||||||
/// declared at the given, potentially unknown, location. If the VectorType
|
/// declared at the given, potentially unknown, location. If the VectorType
|
||||||
/// defined by the arguments would be ill-formed, emit errors and return
|
/// defined by the arguments would be ill-formed, emit errors and return
|
||||||
/// nullptr-wrapping type.
|
/// nullptr-wrapping type.
|
||||||
static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
|
static VectorType getChecked(Location location, ArrayRef<int64_t> shape,
|
||||||
Location location);
|
Type elementType);
|
||||||
|
|
||||||
/// Verify the construction of a vector type.
|
/// Verify the construction of a vector type.
|
||||||
static LogicalResult verifyConstructionInvariants(Location loc,
|
static LogicalResult verifyConstructionInvariants(Location loc,
|
||||||
|
@ -394,8 +394,8 @@ public:
|
||||||
/// type declared at the given, potentially unknown, location. If the
|
/// type declared at the given, potentially unknown, location. If the
|
||||||
/// RankedTensorType defined by the arguments would be ill-formed, emit errors
|
/// RankedTensorType defined by the arguments would be ill-formed, emit errors
|
||||||
/// and return a nullptr-wrapping type.
|
/// and return a nullptr-wrapping type.
|
||||||
static RankedTensorType getChecked(ArrayRef<int64_t> shape, Type elementType,
|
static RankedTensorType getChecked(Location location, ArrayRef<int64_t> shape,
|
||||||
Location location);
|
Type elementType);
|
||||||
|
|
||||||
/// Verify the construction of a ranked tensor type.
|
/// Verify the construction of a ranked tensor type.
|
||||||
static LogicalResult verifyConstructionInvariants(Location loc,
|
static LogicalResult verifyConstructionInvariants(Location loc,
|
||||||
|
@ -424,7 +424,7 @@ public:
|
||||||
/// type declared at the given, potentially unknown, location. If the
|
/// type declared at the given, potentially unknown, location. If the
|
||||||
/// UnrankedTensorType defined by the arguments would be ill-formed, emit
|
/// UnrankedTensorType defined by the arguments would be ill-formed, emit
|
||||||
/// errors and return a nullptr-wrapping type.
|
/// errors and return a nullptr-wrapping type.
|
||||||
static UnrankedTensorType getChecked(Type elementType, Location location);
|
static UnrankedTensorType getChecked(Location location, Type elementType);
|
||||||
|
|
||||||
/// Verify the construction of a unranked tensor type.
|
/// Verify the construction of a unranked tensor type.
|
||||||
static LogicalResult verifyConstructionInvariants(Location loc,
|
static LogicalResult verifyConstructionInvariants(Location loc,
|
||||||
|
@ -527,9 +527,10 @@ public:
|
||||||
/// UnknownLoc. If the MemRefType defined by the arguments would be
|
/// UnknownLoc. If the MemRefType defined by the arguments would be
|
||||||
/// ill-formed, emits errors (to the handler registered with the context or to
|
/// ill-formed, emits errors (to the handler registered with the context or to
|
||||||
/// the error stream) and returns nullptr.
|
/// the error stream) and returns nullptr.
|
||||||
static MemRefType getChecked(ArrayRef<int64_t> shape, Type elementType,
|
static MemRefType getChecked(Location location, ArrayRef<int64_t> shape,
|
||||||
|
Type elementType,
|
||||||
ArrayRef<AffineMap> affineMapComposition,
|
ArrayRef<AffineMap> affineMapComposition,
|
||||||
unsigned memorySpace, Location location);
|
unsigned memorySpace);
|
||||||
|
|
||||||
ArrayRef<int64_t> getShape() const;
|
ArrayRef<int64_t> getShape() const;
|
||||||
|
|
||||||
|
@ -573,8 +574,8 @@ public:
|
||||||
/// type and memory space declared at the given, potentially unknown,
|
/// type and memory space declared at the given, potentially unknown,
|
||||||
/// location. If the UnrankedMemRefType defined by the arguments would be
|
/// location. If the UnrankedMemRefType defined by the arguments would be
|
||||||
/// ill-formed, emit errors and return a nullptr-wrapping type.
|
/// ill-formed, emit errors and return a nullptr-wrapping type.
|
||||||
static UnrankedMemRefType getChecked(Type elementType, unsigned memorySpace,
|
static UnrankedMemRefType getChecked(Location location, Type elementType,
|
||||||
Location location);
|
unsigned memorySpace);
|
||||||
|
|
||||||
/// Verify the construction of a unranked memref type.
|
/// Verify the construction of a unranked memref type.
|
||||||
static LogicalResult verifyConstructionInvariants(Location loc,
|
static LogicalResult verifyConstructionInvariants(Location loc,
|
||||||
|
@ -600,7 +601,7 @@ public:
|
||||||
|
|
||||||
/// Get or create a new TupleType with the provided element types. Assumes the
|
/// Get or create a new TupleType with the provided element types. Assumes the
|
||||||
/// arguments define a well-formed type.
|
/// arguments define a well-formed type.
|
||||||
static TupleType get(TypeRange elementTypes, MLIRContext *context);
|
static TupleType get(MLIRContext *context, TypeRange elementTypes);
|
||||||
|
|
||||||
/// Get or create an empty tuple type.
|
/// Get or create an empty tuple type.
|
||||||
static TupleType get(MLIRContext *context);
|
static TupleType get(MLIRContext *context);
|
||||||
|
|
|
@ -475,8 +475,9 @@ def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
|
||||||
class OpaqueType<string dialect, string name, string description>
|
class OpaqueType<string dialect, string name, string description>
|
||||||
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
|
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
|
||||||
description>,
|
description>,
|
||||||
BuildableType<"::mlir::OpaqueType::get($_builder.getIdentifier(\""
|
BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
|
||||||
# dialect # "\"), \"" # name # "\", $_builder.getContext())">;
|
"$_builder.getIdentifier(\"" # dialect # "\"), \""
|
||||||
|
# name # "\")">;
|
||||||
|
|
||||||
// Function Type
|
// Function Type
|
||||||
|
|
||||||
|
|
|
@ -26,15 +26,15 @@ bool mlirTypeIsAInteger(MlirType type) {
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
|
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
|
||||||
return wrap(IntegerType::get(bitwidth, unwrap(ctx)));
|
return wrap(IntegerType::get(unwrap(ctx), bitwidth));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
|
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
|
||||||
return wrap(IntegerType::get(bitwidth, IntegerType::Signed, unwrap(ctx)));
|
return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
|
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
|
||||||
return wrap(IntegerType::get(bitwidth, IntegerType::Unsigned, unwrap(ctx)));
|
return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned));
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned mlirIntegerTypeGetWidth(MlirType type) {
|
unsigned mlirIntegerTypeGetWidth(MlirType type) {
|
||||||
|
@ -172,8 +172,8 @@ MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
|
||||||
MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape,
|
MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape,
|
||||||
MlirType elementType, MlirLocation loc) {
|
MlirType elementType, MlirLocation loc) {
|
||||||
return wrap(VectorType::getChecked(
|
return wrap(VectorType::getChecked(
|
||||||
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
|
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
|
||||||
unwrap(loc)));
|
unwrap(elementType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -201,8 +201,8 @@ MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, const int64_t *shape,
|
||||||
MlirType elementType,
|
MlirType elementType,
|
||||||
MlirLocation loc) {
|
MlirLocation loc) {
|
||||||
return wrap(RankedTensorType::getChecked(
|
return wrap(RankedTensorType::getChecked(
|
||||||
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
|
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
|
||||||
unwrap(loc)));
|
unwrap(elementType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
|
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
|
||||||
|
@ -211,7 +211,7 @@ MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
|
||||||
|
|
||||||
MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
|
MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType,
|
||||||
MlirLocation loc) {
|
MlirLocation loc) {
|
||||||
return wrap(UnrankedTensorType::getChecked(unwrap(elementType), unwrap(loc)));
|
return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -244,8 +244,8 @@ MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank,
|
||||||
unsigned memorySpace,
|
unsigned memorySpace,
|
||||||
MlirLocation loc) {
|
MlirLocation loc) {
|
||||||
return wrap(MemRefType::getChecked(
|
return wrap(MemRefType::getChecked(
|
||||||
llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
|
unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
|
||||||
llvm::None, memorySpace, unwrap(loc)));
|
unwrap(elementType), llvm::None, memorySpace));
|
||||||
}
|
}
|
||||||
|
|
||||||
intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
|
intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
|
||||||
|
@ -272,8 +272,8 @@ MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
|
||||||
MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
|
MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType,
|
||||||
unsigned memorySpace,
|
unsigned memorySpace,
|
||||||
MlirLocation loc) {
|
MlirLocation loc) {
|
||||||
return wrap(UnrankedMemRefType::getChecked(unwrap(elementType), memorySpace,
|
return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
|
||||||
unwrap(loc)));
|
memorySpace));
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
|
unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
|
||||||
|
@ -290,7 +290,7 @@ MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
|
||||||
MlirType const *elements) {
|
MlirType const *elements) {
|
||||||
SmallVector<Type, 4> types;
|
SmallVector<Type, 4> types;
|
||||||
ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
|
ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
|
||||||
return wrap(TupleType::get(typeRef, unwrap(ctx)));
|
return wrap(TupleType::get(unwrap(ctx), typeRef));
|
||||||
}
|
}
|
||||||
|
|
||||||
intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
|
intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
|
||||||
|
@ -316,7 +316,7 @@ MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
|
||||||
SmallVector<Type, 4> resultsList;
|
SmallVector<Type, 4> resultsList;
|
||||||
(void)unwrapList(numInputs, inputs, inputsList);
|
(void)unwrapList(numInputs, inputs, inputsList);
|
||||||
(void)unwrapList(numResults, results, resultsList);
|
(void)unwrapList(numResults, results, resultsList);
|
||||||
return wrap(FunctionType::get(inputsList, resultsList, unwrap(ctx)));
|
return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList));
|
||||||
}
|
}
|
||||||
|
|
||||||
intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
|
intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
|
||||||
|
|
|
@ -53,52 +53,52 @@ namespace {
|
||||||
struct AsyncAPI {
|
struct AsyncAPI {
|
||||||
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
|
static FunctionType addOrDropRefFunctionType(MLIRContext *ctx) {
|
||||||
auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
|
auto ref = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||||
auto count = IntegerType::get(32, ctx);
|
auto count = IntegerType::get(ctx, 32);
|
||||||
return FunctionType::get({ref, count}, {}, ctx);
|
return FunctionType::get(ctx, {ref, count}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FunctionType createTokenFunctionType(MLIRContext *ctx) {
|
static FunctionType createTokenFunctionType(MLIRContext *ctx) {
|
||||||
return FunctionType::get({}, {TokenType::get(ctx)}, ctx);
|
return FunctionType::get(ctx, {}, {TokenType::get(ctx)});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
|
static FunctionType createGroupFunctionType(MLIRContext *ctx) {
|
||||||
return FunctionType::get({}, {GroupType::get(ctx)}, ctx);
|
return FunctionType::get(ctx, {}, {GroupType::get(ctx)});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
|
static FunctionType emplaceTokenFunctionType(MLIRContext *ctx) {
|
||||||
return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
|
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
|
static FunctionType awaitTokenFunctionType(MLIRContext *ctx) {
|
||||||
return FunctionType::get({TokenType::get(ctx)}, {}, ctx);
|
return FunctionType::get(ctx, {TokenType::get(ctx)}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
|
static FunctionType awaitGroupFunctionType(MLIRContext *ctx) {
|
||||||
return FunctionType::get({GroupType::get(ctx)}, {}, ctx);
|
return FunctionType::get(ctx, {GroupType::get(ctx)}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FunctionType executeFunctionType(MLIRContext *ctx) {
|
static FunctionType executeFunctionType(MLIRContext *ctx) {
|
||||||
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
|
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||||
auto resume = resumeFunctionType(ctx).getPointerTo();
|
auto resume = resumeFunctionType(ctx).getPointerTo();
|
||||||
return FunctionType::get({hdl, resume}, {}, ctx);
|
return FunctionType::get(ctx, {hdl, resume}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
|
static FunctionType addTokenToGroupFunctionType(MLIRContext *ctx) {
|
||||||
auto i64 = IntegerType::get(64, ctx);
|
auto i64 = IntegerType::get(ctx, 64);
|
||||||
return FunctionType::get({TokenType::get(ctx), GroupType::get(ctx)}, {i64},
|
return FunctionType::get(ctx, {TokenType::get(ctx), GroupType::get(ctx)},
|
||||||
ctx);
|
{i64});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
|
static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) {
|
||||||
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
|
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||||
auto resume = resumeFunctionType(ctx).getPointerTo();
|
auto resume = resumeFunctionType(ctx).getPointerTo();
|
||||||
return FunctionType::get({TokenType::get(ctx), hdl, resume}, {}, ctx);
|
return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
|
static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) {
|
||||||
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
|
auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx);
|
||||||
auto resume = resumeFunctionType(ctx).getPointerTo();
|
auto resume = resumeFunctionType(ctx).getPointerTo();
|
||||||
return FunctionType::get({GroupType::get(ctx), hdl, resume}, {}, ctx);
|
return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auxiliary coroutine resume intrinsic wrapper.
|
// Auxiliary coroutine resume intrinsic wrapper.
|
||||||
|
@ -690,7 +690,7 @@ public:
|
||||||
if (!addToGroup.operand().getType().isa<TokenType>())
|
if (!addToGroup.operand().getType().isa<TokenType>())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto i64 = IntegerType::get(64, op->getContext());
|
auto i64 = IntegerType::get(op->getContext(), 64);
|
||||||
rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
|
rewriter.replaceOpWithNewOp<CallOp>(op, kAddTokenToGroup, i64, operands);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -122,7 +122,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Declare vulkan launch function.
|
// Declare vulkan launch function.
|
||||||
auto funcType = FunctionType::get(vulkanLaunchTypes, {}, loc->getContext());
|
auto funcType = builder.getFunctionType(vulkanLaunchTypes, {});
|
||||||
builder.create<FuncOp>(loc, kVulkanLaunch, funcType).setPrivate();
|
builder.create<FuncOp>(loc, kVulkanLaunch, funcType).setPrivate();
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -84,7 +84,7 @@ static LLVMType getPtrToElementType(T containerType,
|
||||||
/// };
|
/// };
|
||||||
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
|
static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
|
||||||
auto *context = t.getContext();
|
auto *context = t.getContext();
|
||||||
auto int64Ty = converter.convertType(IntegerType::get(64, context))
|
auto int64Ty = converter.convertType(IntegerType::get(context, 64))
|
||||||
.cast<LLVM::LLVMType>();
|
.cast<LLVM::LLVMType>();
|
||||||
return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
|
return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty);
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,7 +65,7 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
|
||||||
assert(op->getNumResults() == 0 &&
|
assert(op->getNumResults() == 0 &&
|
||||||
"Library call for linalg operation can be generated only for ops that "
|
"Library call for linalg operation can be generated only for ops that "
|
||||||
"have void return types");
|
"have void return types");
|
||||||
auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext());
|
auto libFnType = rewriter.getFunctionType(inputTypes, {});
|
||||||
|
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
// Insert before module terminator.
|
// Insert before module terminator.
|
||||||
|
|
|
@ -407,8 +407,7 @@ public:
|
||||||
// cover all possible corner cases.
|
// cover all possible corner cases.
|
||||||
if (isSignedIntegerOrVector(srcType) ||
|
if (isSignedIntegerOrVector(srcType) ||
|
||||||
isUnsignedIntegerOrVector(srcType)) {
|
isUnsignedIntegerOrVector(srcType)) {
|
||||||
auto *context = rewriter.getContext();
|
auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
|
||||||
auto signlessType = IntegerType::get(getBitWidth(srcType), context);
|
|
||||||
|
|
||||||
if (srcType.isa<VectorType>()) {
|
if (srcType.isa<VectorType>()) {
|
||||||
auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
|
auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
|
||||||
|
|
|
@ -584,7 +584,7 @@ LogicalResult VectorTransferRewriter<TransferReadOp>::matchAndRewrite(
|
||||||
std::swap(ivsStorage.back(), ivsStorage[coalescedIdx]);
|
std::swap(ivsStorage.back(), ivsStorage[coalescedIdx]);
|
||||||
|
|
||||||
ArrayRef<Value> ivs(ivsStorage);
|
ArrayRef<Value> ivs(ivsStorage);
|
||||||
Value pos = std_index_cast(IntegerType::get(32, ctx), ivs.back());
|
Value pos = std_index_cast(IntegerType::get(ctx, 32), ivs.back());
|
||||||
Value inVector = local(ivs.drop_back());
|
Value inVector = local(ivs.drop_back());
|
||||||
auto loadValue = [&](ArrayRef<Value> indices) {
|
auto loadValue = [&](ArrayRef<Value> indices) {
|
||||||
Value vector = vector_insert_element(remote(indices), inVector, pos);
|
Value vector = vector_insert_element(remote(indices), inVector, pos);
|
||||||
|
@ -671,7 +671,7 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
|
||||||
|
|
||||||
ArrayRef<Value> ivs(ivsStorage);
|
ArrayRef<Value> ivs(ivsStorage);
|
||||||
Value pos =
|
Value pos =
|
||||||
std_index_cast(IntegerType::get(32, op->getContext()), ivs.back());
|
std_index_cast(IntegerType::get(op->getContext(), 32), ivs.back());
|
||||||
auto storeValue = [&](ArrayRef<Value> indices) {
|
auto storeValue = [&](ArrayRef<Value> indices) {
|
||||||
Value scalar = vector_extract_element(local(ivs.drop_back()), pos);
|
Value scalar = vector_extract_element(local(ivs.drop_back()), pos);
|
||||||
remote(indices) = scalar;
|
remote(indices) = scalar;
|
||||||
|
|
|
@ -152,7 +152,7 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result,
|
||||||
int32_t numDependencies = dependencies.size();
|
int32_t numDependencies = dependencies.size();
|
||||||
int32_t numOperands = operands.size();
|
int32_t numOperands = operands.size();
|
||||||
auto operandSegmentSizes = DenseIntElementsAttr::get(
|
auto operandSegmentSizes = DenseIntElementsAttr::get(
|
||||||
VectorType::get({2}, IntegerType::get(32, result.getContext())),
|
VectorType::get({2}, builder.getIntegerType(32)),
|
||||||
{numDependencies, numOperands});
|
{numDependencies, numOperands});
|
||||||
result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
|
result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
|
||||||
|
|
||||||
|
|
|
@ -118,7 +118,7 @@ LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) {
|
||||||
builder.setInsertionPointToStart(value.getParentBlock());
|
builder.setInsertionPointToStart(value.getParentBlock());
|
||||||
|
|
||||||
Location loc = value.getLoc();
|
Location loc = value.getLoc();
|
||||||
auto i32 = IntegerType::get(32, ctx);
|
auto i32 = IntegerType::get(ctx, 32);
|
||||||
|
|
||||||
// Drop the reference count immediately if the value has no uses.
|
// Drop the reference count immediately if the value has no uses.
|
||||||
if (value.getUses().empty()) {
|
if (value.getUses().empty()) {
|
||||||
|
|
|
@ -31,7 +31,7 @@ struct GpuAllReduceRewriter {
|
||||||
: funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
|
: funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
|
||||||
loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
|
loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
|
||||||
indexType(IndexType::get(reduceOp.getContext())),
|
indexType(IndexType::get(reduceOp.getContext())),
|
||||||
int32Type(IntegerType::get(/*width=*/32, reduceOp.getContext())) {}
|
int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}
|
||||||
|
|
||||||
/// Creates an all_reduce across the workgroup.
|
/// Creates an all_reduce across the workgroup.
|
||||||
///
|
///
|
||||||
|
|
|
@ -155,7 +155,7 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
|
||||||
kernelOperandTypes.push_back(operand.getType());
|
kernelOperandTypes.push_back(operand.getType());
|
||||||
}
|
}
|
||||||
FunctionType type =
|
FunctionType type =
|
||||||
FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
|
FunctionType::get(launchOp.getContext(), kernelOperandTypes, {});
|
||||||
auto outlinedFunc = builder.create<gpu::GPUFuncOp>(loc, kernelFnName, type);
|
auto outlinedFunc = builder.create<gpu::GPUFuncOp>(loc, kernelFnName, type);
|
||||||
outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
|
outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
|
||||||
builder.getUnitAttr());
|
builder.getUnitAttr());
|
||||||
|
|
|
@ -120,8 +120,8 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
|
||||||
static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
|
static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) {
|
||||||
auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
|
auto elemTy = op.getType().cast<LLVM::LLVMType>().getPointerElementTy();
|
||||||
|
|
||||||
auto funcTy = FunctionType::get({op.arraySize().getType()}, {op.getType()},
|
auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()},
|
||||||
op.getContext());
|
{op.getType()});
|
||||||
|
|
||||||
p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
|
p << op.getOperationName() << ' ' << op.arraySize() << " x " << elemTy;
|
||||||
if (op.alignment().hasValue() && *op.alignment() != 0)
|
if (op.alignment().hasValue() && *op.alignment() != 0)
|
||||||
|
@ -781,7 +781,7 @@ static void printCallOp(OpAsmPrinter &p, CallOp &op) {
|
||||||
|
|
||||||
// Reconstruct the function MLIR function type from operand and result types.
|
// Reconstruct the function MLIR function type from operand and result types.
|
||||||
p << " : "
|
p << " : "
|
||||||
<< FunctionType::get(args.getTypes(), op.getResultTypes(), op.getContext());
|
<< FunctionType::get(op.getContext(), args.getTypes(), op.getResultTypes());
|
||||||
}
|
}
|
||||||
|
|
||||||
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
|
// <operation> ::= `llvm.call` (function-id | ssa-use) `(` ssa-use-list `)`
|
||||||
|
|
|
@ -76,25 +76,25 @@ static Value allocBuffer(const LinalgPromotionOptions &options,
|
||||||
IntegerAttr alignment_attr;
|
IntegerAttr alignment_attr;
|
||||||
if (alignment.hasValue())
|
if (alignment.hasValue())
|
||||||
alignment_attr =
|
alignment_attr =
|
||||||
IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue());
|
IntegerAttr::get(IntegerType::get(ctx, 64), alignment.getValue());
|
||||||
if (!dynamicBuffers)
|
if (!dynamicBuffers)
|
||||||
if (auto cst = size.getDefiningOp<ConstantIndexOp>())
|
if (auto cst = size.getDefiningOp<ConstantIndexOp>())
|
||||||
return options.useAlloca
|
return options.useAlloca
|
||||||
? std_alloca(MemRefType::get(width * cst.getValue(),
|
? std_alloca(MemRefType::get(width * cst.getValue(),
|
||||||
IntegerType::get(8, ctx)),
|
IntegerType::get(ctx, 8)),
|
||||||
ValueRange{}, alignment_attr)
|
ValueRange{}, alignment_attr)
|
||||||
.value
|
.value
|
||||||
: std_alloc(MemRefType::get(width * cst.getValue(),
|
: std_alloc(MemRefType::get(width * cst.getValue(),
|
||||||
IntegerType::get(8, ctx)),
|
IntegerType::get(ctx, 8)),
|
||||||
ValueRange{}, alignment_attr)
|
ValueRange{}, alignment_attr)
|
||||||
.value;
|
.value;
|
||||||
Value mul =
|
Value mul =
|
||||||
folded_std_muli(folder, folded_std_constant_index(folder, width), size);
|
folded_std_muli(folder, folded_std_constant_index(folder, width), size);
|
||||||
return options.useAlloca
|
return options.useAlloca
|
||||||
? std_alloca(MemRefType::get(-1, IntegerType::get(8, ctx)), mul,
|
? std_alloca(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
|
||||||
alignment_attr)
|
alignment_attr)
|
||||||
.value
|
.value
|
||||||
: std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul,
|
: std_alloc(MemRefType::get(-1, IntegerType::get(ctx, 8)), mul,
|
||||||
alignment_attr)
|
alignment_attr)
|
||||||
.value;
|
.value;
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
|
||||||
int64_t &qmax) {
|
int64_t &qmax) {
|
||||||
// Hard-coded type mapping from TFLite.
|
// Hard-coded type mapping from TFLite.
|
||||||
if (numBits <= 8) {
|
if (numBits <= 8) {
|
||||||
storageType = IntegerType::get(8, ctx);
|
storageType = IntegerType::get(ctx, 8);
|
||||||
if (isSigned) {
|
if (isSigned) {
|
||||||
qmin = -128;
|
qmin = -128;
|
||||||
qmax = 127;
|
qmax = 127;
|
||||||
|
@ -27,7 +27,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
|
||||||
qmax = 255;
|
qmax = 255;
|
||||||
}
|
}
|
||||||
} else if (numBits <= 16) {
|
} else if (numBits <= 16) {
|
||||||
storageType = IntegerType::get(16, ctx);
|
storageType = IntegerType::get(ctx, 16);
|
||||||
if (isSigned) {
|
if (isSigned) {
|
||||||
qmin = -32768;
|
qmin = -32768;
|
||||||
qmax = 32767;
|
qmax = 32767;
|
||||||
|
@ -36,7 +36,7 @@ static bool getDefaultStorageParams(unsigned numBits, bool narrowRange,
|
||||||
qmax = 65535;
|
qmax = 65535;
|
||||||
}
|
}
|
||||||
} else if (numBits <= 32) {
|
} else if (numBits <= 32) {
|
||||||
storageType = IntegerType::get(32, ctx);
|
storageType = IntegerType::get(ctx, 32);
|
||||||
if (isSigned) {
|
if (isSigned) {
|
||||||
qmin = std::numeric_limits<int32_t>::min();
|
qmin = std::numeric_limits<int32_t>::min();
|
||||||
qmax = std::numeric_limits<int32_t>::max();
|
qmax = std::numeric_limits<int32_t>::max();
|
||||||
|
|
|
@ -79,7 +79,7 @@ UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) {
|
||||||
int64_t chunkSize =
|
int64_t chunkSize =
|
||||||
std::accumulate(std::next(shape.begin(), quantizationDim + 1),
|
std::accumulate(std::next(shape.begin(), quantizationDim + 1),
|
||||||
shape.end(), 1, std::multiplies<int64_t>());
|
shape.end(), 1, std::multiplies<int64_t>());
|
||||||
Type newElementType = IntegerType::get(storageBitWidth, attr.getContext());
|
Type newElementType = IntegerType::get(attr.getContext(), storageBitWidth);
|
||||||
return attr.mapValues(newElementType, [&](const APFloat &old) {
|
return attr.mapValues(newElementType, [&](const APFloat &old) {
|
||||||
int chunkIndex = (flattenIndex++) / chunkSize;
|
int chunkIndex = (flattenIndex++) / chunkSize;
|
||||||
return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
|
return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
|
||||||
|
|
|
@ -96,7 +96,7 @@ void mlir::outlineIfOp(OpBuilder &b, scf::IfOp ifOp, FuncOp *thenFn,
|
||||||
|
|
||||||
ValueRange values(captures.getArrayRef());
|
ValueRange values(captures.getArrayRef());
|
||||||
FunctionType type =
|
FunctionType type =
|
||||||
FunctionType::get(values.getTypes(), ifOp.getResultTypes(), ctx);
|
FunctionType::get(ctx, values.getTypes(), ifOp.getResultTypes());
|
||||||
auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
|
auto outlinedFunc = b.create<FuncOp>(loc, funcName, type);
|
||||||
b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
|
b.setInsertionPointToStart(outlinedFunc.addEntryBlock());
|
||||||
BlockAndValueMapping bvm;
|
BlockAndValueMapping bvm;
|
||||||
|
|
|
@ -123,7 +123,7 @@ spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize, MLIRContext *context) {
|
||||||
assert(localSize.size() == 3);
|
assert(localSize.size() == 3);
|
||||||
return spirv::EntryPointABIAttr::get(
|
return spirv::EntryPointABIAttr::get(
|
||||||
DenseElementsAttr::get<int32_t>(
|
DenseElementsAttr::get<int32_t>(
|
||||||
VectorType::get(3, IntegerType::get(32, context)), localSize)
|
VectorType::get(3, IntegerType::get(context, 32)), localSize)
|
||||||
.cast<DenseIntElementsAttr>(),
|
.cast<DenseIntElementsAttr>(),
|
||||||
context);
|
context);
|
||||||
}
|
}
|
||||||
|
|
|
@ -93,7 +93,7 @@ Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
|
||||||
// instructions. The Vulkan spec requires the builtins like
|
// instructions. The Vulkan spec requires the builtins like
|
||||||
// GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
|
// GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
|
||||||
// SExtended to 64-bit for index computations.
|
// SExtended to 64-bit for index computations.
|
||||||
return IntegerType::get(32, context);
|
return IntegerType::get(context, 32);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Mapping between SPIR-V storage classes to memref memory spaces.
|
/// Mapping between SPIR-V storage classes to memref memory spaces.
|
||||||
|
@ -260,8 +260,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
|
||||||
|
|
||||||
auto intType = type.cast<IntegerType>();
|
auto intType = type.cast<IntegerType>();
|
||||||
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
|
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
|
||||||
return IntegerType::get(/*width=*/32, intType.getSignedness(),
|
return IntegerType::get(targetEnv.getContext(), /*width=*/32,
|
||||||
targetEnv.getContext());
|
intType.getSignedness());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
|
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
|
||||||
|
|
|
@ -714,7 +714,7 @@ LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionType CallOp::getCalleeType() {
|
FunctionType CallOp::getCalleeType() {
|
||||||
return FunctionType::get(getOperandTypes(), getResultTypes(), getContext());
|
return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -753,7 +753,7 @@ void CallIndirectOp::getCanonicalizationPatterns(
|
||||||
|
|
||||||
// Return the type of the same shape (scalar, vector or tensor) containing i1.
|
// Return the type of the same shape (scalar, vector or tensor) containing i1.
|
||||||
static Type getI1SameShape(Type type) {
|
static Type getI1SameShape(Type type) {
|
||||||
auto i1Type = IntegerType::get(1, type.getContext());
|
auto i1Type = IntegerType::get(type.getContext(), 1);
|
||||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||||
return RankedTensorType::get(tensorType.getShape(), i1Type);
|
return RankedTensorType::get(tensorType.getShape(), i1Type);
|
||||||
if (type.isa<UnrankedTensorType>())
|
if (type.isa<UnrankedTensorType>())
|
||||||
|
@ -914,7 +914,7 @@ OpFoldResult CmpFOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return {};
|
return {};
|
||||||
|
|
||||||
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
|
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
|
||||||
return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
|
return IntegerAttr::get(IntegerType::get(getContext(), 1), APInt(1, val));
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1426,7 +1426,7 @@ static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
|
||||||
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
|
static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
|
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
|
||||||
return IntegerAttr::get(IntegerType::get(64, context), APInt(64, v));
|
return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
|
||||||
});
|
});
|
||||||
return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
|
return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
|
||||||
}
|
}
|
||||||
|
@ -2767,7 +2767,7 @@ static ParseResult parseTupleOp(OpAsmParser &parser, OperationState &result) {
|
||||||
parser.parseOptionalAttrDict(result.attributes) ||
|
parser.parseOptionalAttrDict(result.attributes) ||
|
||||||
parser.parseColonTypeList(types) ||
|
parser.parseColonTypeList(types) ||
|
||||||
parser.resolveOperands(operandInfos, types, loc, result.operands) ||
|
parser.resolveOperands(operandInfos, types, loc, result.operands) ||
|
||||||
parser.addTypeToList(TupleType::get(types, ctx), result.types));
|
parser.addTypeToList(TupleType::get(ctx, types), result.types));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, TupleOp op) {
|
static void print(OpAsmPrinter &p, TupleOp op) {
|
||||||
|
|
|
@ -215,7 +215,7 @@ static TupleType generateExtractSlicesOpResultType(VectorType vectorType,
|
||||||
// Create Vector type and add to 'vectorTypes[i]'.
|
// Create Vector type and add to 'vectorTypes[i]'.
|
||||||
vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
|
vectorTypes[i] = VectorType::get(sliceSizes, vectorType.getElementType());
|
||||||
}
|
}
|
||||||
return TupleType::get(vectorTypes, builder.getContext());
|
return builder.getTupleType(vectorTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnrolledVectorState aggregates per-operand/result vector state required for
|
// UnrolledVectorState aggregates per-operand/result vector state required for
|
||||||
|
|
|
@ -52,27 +52,27 @@ FloatType Builder::getF64Type() { return FloatType::getF64(context); }
|
||||||
|
|
||||||
IndexType Builder::getIndexType() { return IndexType::get(context); }
|
IndexType Builder::getIndexType() { return IndexType::get(context); }
|
||||||
|
|
||||||
IntegerType Builder::getI1Type() { return IntegerType::get(1, context); }
|
IntegerType Builder::getI1Type() { return IntegerType::get(context, 1); }
|
||||||
|
|
||||||
IntegerType Builder::getI32Type() { return IntegerType::get(32, context); }
|
IntegerType Builder::getI32Type() { return IntegerType::get(context, 32); }
|
||||||
|
|
||||||
IntegerType Builder::getI64Type() { return IntegerType::get(64, context); }
|
IntegerType Builder::getI64Type() { return IntegerType::get(context, 64); }
|
||||||
|
|
||||||
IntegerType Builder::getIntegerType(unsigned width) {
|
IntegerType Builder::getIntegerType(unsigned width) {
|
||||||
return IntegerType::get(width, context);
|
return IntegerType::get(context, width);
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
|
IntegerType Builder::getIntegerType(unsigned width, bool isSigned) {
|
||||||
return IntegerType::get(
|
return IntegerType::get(
|
||||||
width, isSigned ? IntegerType::Signed : IntegerType::Unsigned, context);
|
context, width, isSigned ? IntegerType::Signed : IntegerType::Unsigned);
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
|
FunctionType Builder::getFunctionType(TypeRange inputs, TypeRange results) {
|
||||||
return FunctionType::get(inputs, results, context);
|
return FunctionType::get(context, inputs, results);
|
||||||
}
|
}
|
||||||
|
|
||||||
TupleType Builder::getTupleType(TypeRange elementTypes) {
|
TupleType Builder::getTupleType(TypeRange elementTypes) {
|
||||||
return TupleType::get(elementTypes, context);
|
return TupleType::get(context, elementTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
NoneType Builder::getNoneType() { return NoneType::get(context); }
|
NoneType Builder::getNoneType() { return NoneType::get(context); }
|
||||||
|
|
|
@ -179,7 +179,7 @@ FuncOp FuncOp::clone(BlockAndValueMapping &mapper) {
|
||||||
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
|
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
|
||||||
if (!mapper.contains(getArgument(i)))
|
if (!mapper.contains(getArgument(i)))
|
||||||
inputTypes.push_back(newType.getInput(i));
|
inputTypes.push_back(newType.getInput(i));
|
||||||
newType = FunctionType::get(inputTypes, newType.getResults(), getContext());
|
newType = FunctionType::get(getContext(), inputTypes, newType.getResults());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the new function.
|
// Create the new function.
|
||||||
|
|
|
@ -35,7 +35,7 @@ ComplexType ComplexType::get(Type elementType) {
|
||||||
return Base::get(elementType.getContext(), elementType);
|
return Base::get(elementType.getContext(), elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
ComplexType ComplexType::getChecked(Type elementType, Location location) {
|
ComplexType ComplexType::getChecked(Location location, Type elementType) {
|
||||||
return Base::getChecked(location, elementType);
|
return Base::getChecked(location, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,7 +76,7 @@ IntegerType::SignednessSemantics IntegerType::getSignedness() const {
|
||||||
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
|
IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
|
||||||
if (!scale)
|
if (!scale)
|
||||||
return IntegerType();
|
return IntegerType();
|
||||||
return IntegerType::get(scale * getWidth(), getSignedness(), getContext());
|
return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -126,8 +126,8 @@ FloatType FloatType::scaleElementBitwidth(unsigned scale) {
|
||||||
// FunctionType
|
// FunctionType
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
|
FunctionType FunctionType::get(MLIRContext *context, TypeRange inputs,
|
||||||
MLIRContext *context) {
|
TypeRange results) {
|
||||||
return Base::get(context, inputs, results);
|
return Base::get(context, inputs, results);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,20 +182,20 @@ FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
|
||||||
newResultTypes = newResultTypesBuffer;
|
newResultTypes = newResultTypesBuffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
return get(newInputTypes, newResultTypes, getContext());
|
return get(getContext(), newInputTypes, newResultTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// OpaqueType
|
// OpaqueType
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
|
OpaqueType OpaqueType::get(MLIRContext *context, Identifier dialect,
|
||||||
MLIRContext *context) {
|
StringRef typeData) {
|
||||||
return Base::get(context, dialect, typeData);
|
return Base::get(context, dialect, typeData);
|
||||||
}
|
}
|
||||||
|
|
||||||
OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
|
OpaqueType OpaqueType::getChecked(Location location, Identifier dialect,
|
||||||
MLIRContext *context, Location location) {
|
StringRef typeData) {
|
||||||
return Base::getChecked(location, dialect, typeData);
|
return Base::getChecked(location, dialect, typeData);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -313,8 +313,8 @@ VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
|
||||||
return Base::get(elementType.getContext(), shape, elementType);
|
return Base::get(elementType.getContext(), shape, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
|
VectorType VectorType::getChecked(Location location, ArrayRef<int64_t> shape,
|
||||||
Location location) {
|
Type elementType) {
|
||||||
return Base::getChecked(location, shape, elementType);
|
return Base::getChecked(location, shape, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -379,9 +379,9 @@ RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
|
||||||
return Base::get(elementType.getContext(), shape, elementType);
|
return Base::get(elementType.getContext(), shape, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
|
RankedTensorType RankedTensorType::getChecked(Location location,
|
||||||
Type elementType,
|
ArrayRef<int64_t> shape,
|
||||||
Location location) {
|
Type elementType) {
|
||||||
return Base::getChecked(location, shape, elementType);
|
return Base::getChecked(location, shape, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -406,8 +406,8 @@ UnrankedTensorType UnrankedTensorType::get(Type elementType) {
|
||||||
return Base::get(elementType.getContext(), elementType);
|
return Base::get(elementType.getContext(), elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
|
UnrankedTensorType UnrankedTensorType::getChecked(Location location,
|
||||||
Location location) {
|
Type elementType) {
|
||||||
return Base::getChecked(location, elementType);
|
return Base::getChecked(location, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -448,9 +448,10 @@ MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
|
||||||
/// UnknownLoc. If the MemRefType defined by the arguments would be
|
/// UnknownLoc. If the MemRefType defined by the arguments would be
|
||||||
/// ill-formed, emits errors (to the handler registered with the context or to
|
/// ill-formed, emits errors (to the handler registered with the context or to
|
||||||
/// the error stream) and returns nullptr.
|
/// the error stream) and returns nullptr.
|
||||||
MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
|
MemRefType MemRefType::getChecked(Location location, ArrayRef<int64_t> shape,
|
||||||
|
Type elementType,
|
||||||
ArrayRef<AffineMap> affineMapComposition,
|
ArrayRef<AffineMap> affineMapComposition,
|
||||||
unsigned memorySpace, Location location) {
|
unsigned memorySpace) {
|
||||||
return getImpl(shape, elementType, affineMapComposition, memorySpace,
|
return getImpl(shape, elementType, affineMapComposition, memorySpace,
|
||||||
location);
|
location);
|
||||||
}
|
}
|
||||||
|
@ -524,9 +525,9 @@ UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
|
||||||
return Base::get(elementType.getContext(), elementType, memorySpace);
|
return Base::get(elementType.getContext(), elementType, memorySpace);
|
||||||
}
|
}
|
||||||
|
|
||||||
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
|
UnrankedMemRefType UnrankedMemRefType::getChecked(Location location,
|
||||||
unsigned memorySpace,
|
Type elementType,
|
||||||
Location location) {
|
unsigned memorySpace) {
|
||||||
return Base::getChecked(location, elementType, memorySpace);
|
return Base::getChecked(location, elementType, memorySpace);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -694,12 +695,12 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
|
||||||
|
|
||||||
/// Get or create a new TupleType with the provided element types. Assumes the
|
/// Get or create a new TupleType with the provided element types. Assumes the
|
||||||
/// arguments define a well-formed type.
|
/// arguments define a well-formed type.
|
||||||
TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
|
TupleType TupleType::get(MLIRContext *context, TypeRange elementTypes) {
|
||||||
return Base::get(context, elementTypes);
|
return Base::get(context, elementTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get or create an empty tuple type.
|
/// Get or create an empty tuple type.
|
||||||
TupleType TupleType::get(MLIRContext *context) { return get({}, context); }
|
TupleType TupleType::get(MLIRContext *context) { return get(context, {}); }
|
||||||
|
|
||||||
/// Return the elements types for this tuple.
|
/// Return the elements types for this tuple.
|
||||||
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
|
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
|
||||||
|
|
|
@ -82,7 +82,7 @@ Type Dialect::parseType(DialectAsmParser &parser) const {
|
||||||
// If this dialect allows unknown types, then represent this with OpaqueType.
|
// If this dialect allows unknown types, then represent this with OpaqueType.
|
||||||
if (allowsUnknownTypes()) {
|
if (allowsUnknownTypes()) {
|
||||||
auto ns = Identifier::get(getNamespace(), getContext());
|
auto ns = Identifier::get(getNamespace(), getContext());
|
||||||
return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext());
|
return OpaqueType::get(getContext(), ns, parser.getFullSymbolSpec());
|
||||||
}
|
}
|
||||||
|
|
||||||
parser.emitError(parser.getNameLoc())
|
parser.emitError(parser.getNameLoc())
|
||||||
|
|
|
@ -772,25 +772,23 @@ getCachedIntegerType(unsigned width,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
|
IntegerType IntegerType::get(MLIRContext *context, unsigned width) {
|
||||||
return get(width, IntegerType::Signless, context);
|
return get(context, width, IntegerType::Signless);
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerType IntegerType::get(unsigned width,
|
IntegerType IntegerType::get(MLIRContext *context, unsigned width,
|
||||||
IntegerType::SignednessSemantics signedness,
|
IntegerType::SignednessSemantics signedness) {
|
||||||
MLIRContext *context) {
|
|
||||||
if (auto cached = getCachedIntegerType(width, signedness, context))
|
if (auto cached = getCachedIntegerType(width, signedness, context))
|
||||||
return cached;
|
return cached;
|
||||||
return Base::get(context, width, signedness);
|
return Base::get(context, width, signedness);
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerType IntegerType::getChecked(unsigned width, Location location) {
|
IntegerType IntegerType::getChecked(Location location, unsigned width) {
|
||||||
return getChecked(width, IntegerType::Signless, location);
|
return getChecked(location, width, IntegerType::Signless);
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerType IntegerType::getChecked(unsigned width,
|
IntegerType IntegerType::getChecked(Location location, unsigned width,
|
||||||
SignednessSemantics signedness,
|
SignednessSemantics signedness) {
|
||||||
Location location) {
|
|
||||||
if (auto cached =
|
if (auto cached =
|
||||||
getCachedIntegerType(width, signedness, location->getContext()))
|
getCachedIntegerType(width, signedness, location->getContext()))
|
||||||
return cached;
|
return cached;
|
||||||
|
|
|
@ -178,7 +178,7 @@ Operation::Operation(Location location, OperationName name,
|
||||||
if (hasSingleResult)
|
if (hasSingleResult)
|
||||||
resultType = resultTypes.front();
|
resultType = resultTypes.front();
|
||||||
else
|
else
|
||||||
resultType = TupleType::get(resultTypes, location->getContext());
|
resultType = TupleType::get(location->getContext(), resultTypes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -63,7 +63,7 @@ void Value::setType(Type newType) {
|
||||||
return;
|
return;
|
||||||
auto newTypes = llvm::to_vector<4>(curTypes);
|
auto newTypes = llvm::to_vector<4>(curTypes);
|
||||||
newTypes[resultNo] = newType;
|
newTypes[resultNo] = newType;
|
||||||
owner->resultType = TupleType::get(newTypes, newType.getContext());
|
owner->resultType = TupleType::get(newType.getContext(), newTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If this value is the result of an Operation, return the operation that
|
/// If this value is the result of an Operation, return the operation that
|
||||||
|
|
|
@ -563,8 +563,8 @@ Type Parser::parseExtendedType() {
|
||||||
|
|
||||||
// Otherwise, form a new opaque type.
|
// Otherwise, form a new opaque type.
|
||||||
return OpaqueType::getChecked(
|
return OpaqueType::getChecked(
|
||||||
Identifier::get(dialectName, state.context), symbolData,
|
getEncodedSourceLocation(loc),
|
||||||
state.context, getEncodedSourceLocation(loc));
|
Identifier::get(dialectName, state.context), symbolData);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -338,7 +338,7 @@ Type Parser::parseNonFunctionType() {
|
||||||
signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
|
signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
|
||||||
|
|
||||||
consumeToken(Token::inttype);
|
consumeToken(Token::inttype);
|
||||||
return IntegerType::get(width.getValue(), signSemantics, getContext());
|
return IntegerType::get(getContext(), width.getValue(), signSemantics);
|
||||||
}
|
}
|
||||||
|
|
||||||
// float-type
|
// float-type
|
||||||
|
@ -432,7 +432,7 @@ Type Parser::parseTupleType() {
|
||||||
parseToken(Token::greater, "expected '>' in tuple type"))
|
parseToken(Token::greater, "expected '>' in tuple type"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
return TupleType::get(types, getContext());
|
return TupleType::get(getContext(), types);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse a vector type.
|
/// Parse a vector type.
|
||||||
|
|
|
@ -236,7 +236,7 @@ Type Importer::getStdTypeForAttr(LLVMType type) {
|
||||||
Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
|
Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
|
||||||
if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
|
if (auto *ci = dyn_cast<llvm::ConstantInt>(value))
|
||||||
return b.getIntegerAttr(
|
return b.getIntegerAttr(
|
||||||
IntegerType::get(ci->getType()->getBitWidth(), context),
|
IntegerType::get(context, ci->getType()->getBitWidth()),
|
||||||
ci->getValue());
|
ci->getValue());
|
||||||
if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
|
if (auto *c = dyn_cast<llvm::ConstantDataArray>(value))
|
||||||
if (c->isString())
|
if (c->isString())
|
||||||
|
|
|
@ -1182,7 +1182,7 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
|
||||||
// signless semantics for such cases.
|
// signless semantics for such cases.
|
||||||
auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
|
auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
|
||||||
: IntegerType::SignednessSemantics::Signless;
|
: IntegerType::SignednessSemantics::Signless;
|
||||||
typeMap[operands[0]] = IntegerType::get(operands[1], sign, context);
|
typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
|
||||||
} break;
|
} break;
|
||||||
case spirv::Opcode::OpTypeFloat: {
|
case spirv::Opcode::OpTypeFloat: {
|
||||||
if (operands.size() != 2)
|
if (operands.size() != 2)
|
||||||
|
@ -1345,7 +1345,7 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
|
||||||
if (!isVoidType(returnType)) {
|
if (!isVoidType(returnType)) {
|
||||||
returnTypes = llvm::makeArrayRef(returnType);
|
returnTypes = llvm::makeArrayRef(returnType);
|
||||||
}
|
}
|
||||||
typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context);
|
typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1267,7 +1267,7 @@ LogicalResult Serializer::prepareBasicType(
|
||||||
}
|
}
|
||||||
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
|
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
|
||||||
auto getConstantOp = [&](uint32_t id) {
|
auto getConstantOp = [&](uint32_t id) {
|
||||||
auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id);
|
auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
|
||||||
return prepareConstantInt(loc, attr);
|
return prepareConstantInt(loc, attr);
|
||||||
};
|
};
|
||||||
operands.push_back(elementTypeID);
|
operands.push_back(elementTypeID);
|
||||||
|
|
|
@ -35,8 +35,8 @@ static void updateFuncOp(FuncOp func,
|
||||||
// Add the new arguments to the function type.
|
// Add the new arguments to the function type.
|
||||||
auto newArgTypes = llvm::to_vector<6>(
|
auto newArgTypes = llvm::to_vector<6>(
|
||||||
llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
|
llvm::concat<const Type>(functionType.getInputs(), erasedResultTypes));
|
||||||
auto newFunctionType = FunctionType::get(
|
auto newFunctionType = FunctionType::get(func.getContext(), newArgTypes,
|
||||||
newArgTypes, functionType.getResults(), func.getContext());
|
functionType.getResults());
|
||||||
func.setType(newFunctionType);
|
func.setType(newFunctionType);
|
||||||
|
|
||||||
// Transfer the result attributes to arg attributes.
|
// Transfer the result attributes to arg attributes.
|
||||||
|
|
|
@ -230,9 +230,8 @@ void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
|
||||||
|
|
||||||
// We create a new function type and modify the function signature with this
|
// We create a new function type and modify the function signature with this
|
||||||
// new type.
|
// new type.
|
||||||
newFuncType = FunctionType::get(/*inputs=*/argTypes,
|
newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
|
||||||
/*results=*/resultTypes,
|
/*results=*/resultTypes);
|
||||||
/*context=*/&getContext());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Since we update the function signature, it might affect the result types at
|
// Since we update the function signature, it might affect the result types at
|
||||||
|
@ -463,9 +462,9 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionType newFuncType = FunctionType::get(/*inputs=*/inputTypes,
|
FunctionType newFuncType =
|
||||||
/*results=*/resultTypes,
|
FunctionType::get(&getContext(), /*inputs=*/inputTypes,
|
||||||
/*context=*/&getContext());
|
/*results=*/resultTypes);
|
||||||
// Setting the new function signature for this external function.
|
// Setting the new function signature for this external function.
|
||||||
funcOp.setType(newFuncType);
|
funcOp.setType(newFuncType);
|
||||||
}
|
}
|
||||||
|
|
|
@ -2522,8 +2522,8 @@ struct FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
|
||||||
|
|
||||||
// Update the function signature in-place.
|
// Update the function signature in-place.
|
||||||
rewriter.updateRootInPlace(funcOp, [&] {
|
rewriter.updateRootInPlace(funcOp, [&] {
|
||||||
funcOp.setType(FunctionType::get(result.getConvertedTypes(), newResults,
|
funcOp.setType(FunctionType::get(funcOp.getContext(),
|
||||||
funcOp.getContext()));
|
result.getConvertedTypes(), newResults));
|
||||||
});
|
});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,7 +56,7 @@ static FuncOp makeFunction(StringRef name, ArrayRef<Type> results = {},
|
||||||
ArrayRef<Type> args = {}) {
|
ArrayRef<Type> args = {}) {
|
||||||
auto &ctx = globalContext();
|
auto &ctx = globalContext();
|
||||||
auto function = FuncOp::create(UnknownLoc::get(&ctx), name,
|
auto function = FuncOp::create(UnknownLoc::get(&ctx), name,
|
||||||
FunctionType::get(args, results, &ctx));
|
FunctionType::get(&ctx, args, results));
|
||||||
function.addEntryBlock();
|
function.addEntryBlock();
|
||||||
return function;
|
return function;
|
||||||
}
|
}
|
||||||
|
@ -277,7 +277,7 @@ TEST_FUNC(builder_blocks) {
|
||||||
|
|
||||||
TEST_FUNC(builder_cond_branch) {
|
TEST_FUNC(builder_cond_branch) {
|
||||||
auto f = makeFunction("builder_cond_branch", {},
|
auto f = makeFunction("builder_cond_branch", {},
|
||||||
{IntegerType::get(1, &globalContext())});
|
{IntegerType::get(&globalContext(), 1)});
|
||||||
|
|
||||||
OpBuilder builder(f.getBody());
|
OpBuilder builder(f.getBody());
|
||||||
ScopedContext scope(builder, f.getLoc());
|
ScopedContext scope(builder, f.getLoc());
|
||||||
|
@ -390,8 +390,8 @@ TEST_FUNC(insertion_in_block) {
|
||||||
|
|
||||||
TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) {
|
TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) {
|
||||||
using namespace edsc::op;
|
using namespace edsc::op;
|
||||||
auto i1Type = IntegerType::get(1, &globalContext());
|
auto i1Type = IntegerType::get(&globalContext(), 1);
|
||||||
auto i8Type = IntegerType::get(8, &globalContext());
|
auto i8Type = IntegerType::get(&globalContext(), 8);
|
||||||
auto memrefType = MemRefType::get({}, i1Type, {}, 0);
|
auto memrefType = MemRefType::get({}, i1Type, {}, 0);
|
||||||
auto f = makeFunction("zero_and_std_sign_extendi_op", {},
|
auto f = makeFunction("zero_and_std_sign_extendi_op", {},
|
||||||
{memrefType, memrefType});
|
{memrefType, memrefType});
|
||||||
|
@ -414,7 +414,7 @@ TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_FUNC(operator_or) {
|
TEST_FUNC(operator_or) {
|
||||||
auto i1Type = IntegerType::get(/*width=*/1, &globalContext());
|
auto i1Type = IntegerType::get(&globalContext(), /*width=*/1);
|
||||||
auto f = makeFunction("operator_or", {}, {i1Type, i1Type});
|
auto f = makeFunction("operator_or", {}, {i1Type, i1Type});
|
||||||
|
|
||||||
OpBuilder builder(f.getBody());
|
OpBuilder builder(f.getBody());
|
||||||
|
@ -435,7 +435,7 @@ TEST_FUNC(operator_or) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_FUNC(operator_and) {
|
TEST_FUNC(operator_and) {
|
||||||
auto i1Type = IntegerType::get(/*width=*/1, &globalContext());
|
auto i1Type = IntegerType::get(&globalContext(), /*width=*/1);
|
||||||
auto f = makeFunction("operator_and", {}, {i1Type, i1Type});
|
auto f = makeFunction("operator_and", {}, {i1Type, i1Type});
|
||||||
|
|
||||||
OpBuilder builder(f.getBody());
|
OpBuilder builder(f.getBody());
|
||||||
|
@ -536,7 +536,7 @@ TEST_FUNC(fptrunc_f32_bf16) {
|
||||||
|
|
||||||
TEST_FUNC(select_op_i32) {
|
TEST_FUNC(select_op_i32) {
|
||||||
using namespace edsc::op;
|
using namespace edsc::op;
|
||||||
auto i32Type = IntegerType::get(32, &globalContext());
|
auto i32Type = IntegerType::get(&globalContext(), 32);
|
||||||
auto memrefType = MemRefType::get(
|
auto memrefType = MemRefType::get(
|
||||||
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, i32Type, {}, 0);
|
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, i32Type, {}, 0);
|
||||||
auto f = makeFunction("select_op", {}, {memrefType});
|
auto f = makeFunction("select_op", {}, {memrefType});
|
||||||
|
|
|
@ -653,7 +653,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
|
||||||
}
|
}
|
||||||
int64_t dim =
|
int64_t dim =
|
||||||
sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
|
sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
|
||||||
auto type = IntegerType::get(17, context);
|
auto type = IntegerType::get(context, 17);
|
||||||
inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
|
inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -509,7 +509,7 @@ struct TestTypeConverter : public TypeConverter {
|
||||||
|
|
||||||
// Convert I42 to I43.
|
// Convert I42 to I43.
|
||||||
if (t.isInteger(42)) {
|
if (t.isInteger(42)) {
|
||||||
results.push_back(IntegerType::get(43, t.getContext()));
|
results.push_back(IntegerType::get(t.getContext(), 43));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -69,9 +69,7 @@ struct TestDecomposeCallGraphTypes
|
||||||
Location loc) -> Optional<Value> {
|
Location loc) -> Optional<Value> {
|
||||||
if (inputs.size() == 1)
|
if (inputs.size() == 1)
|
||||||
return llvm::None;
|
return llvm::None;
|
||||||
TypeRange TypeRange = inputs.getTypes();
|
TupleType tuple = builder.getTupleType(inputs.getTypes());
|
||||||
SmallVector<Type, 2> types(TypeRange.begin(), TypeRange.end());
|
|
||||||
TupleType tuple = TupleType::get(types, builder.getContext());
|
|
||||||
Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
|
Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
|
||||||
return value;
|
return value;
|
||||||
});
|
});
|
||||||
|
|
|
@ -59,7 +59,7 @@ ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
|
||||||
} else {
|
} else {
|
||||||
tensorType = RankedTensorType::get(shape, eleType);
|
tensorType = RankedTensorType::get(shape, eleType);
|
||||||
}
|
}
|
||||||
auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx));
|
auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(ctx, 64));
|
||||||
auto indices =
|
auto indices =
|
||||||
DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
|
DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
|
||||||
auto valuesType = RankedTensorType::get({1}, eleType);
|
auto valuesType = RankedTensorType::get({1}, eleType);
|
||||||
|
@ -77,7 +77,7 @@ UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
|
||||||
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
|
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
|
||||||
MLIRContext ctx;
|
MLIRContext ctx;
|
||||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
IntegerType convertedType = IntegerType::get(&ctx, 8);
|
||||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
|
||||||
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
|
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
|
||||||
MLIRContext ctx;
|
MLIRContext ctx;
|
||||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
IntegerType convertedType = IntegerType::get(&ctx, 8);
|
||||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||||
auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
|
auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
|
||||||
|
@ -120,7 +120,7 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
|
||||||
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
|
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
|
||||||
MLIRContext ctx;
|
MLIRContext ctx;
|
||||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
IntegerType convertedType = IntegerType::get(&ctx, 8);
|
||||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||||
auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
|
auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
|
||||||
|
@ -145,7 +145,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
|
||||||
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
|
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
|
||||||
MLIRContext ctx;
|
MLIRContext ctx;
|
||||||
ctx.getOrLoadDialect<QuantizationDialect>();
|
ctx.getOrLoadDialect<QuantizationDialect>();
|
||||||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
IntegerType convertedType = IntegerType::get(&ctx, 8);
|
||||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||||
auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});
|
auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});
|
||||||
|
|
|
@ -33,7 +33,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
|
||||||
namespace {
|
namespace {
|
||||||
TEST(DenseSplatTest, BoolSplat) {
|
TEST(DenseSplatTest, BoolSplat) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
IntegerType boolTy = IntegerType::get(1, &context);
|
IntegerType boolTy = IntegerType::get(&context, 1);
|
||||||
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
|
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
|
||||||
|
|
||||||
// Check that splat is automatically detected for boolean values.
|
// Check that splat is automatically detected for boolean values.
|
||||||
|
@ -58,7 +58,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
|
||||||
constexpr int64_t boolCount = 56;
|
constexpr int64_t boolCount = 56;
|
||||||
|
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
IntegerType boolTy = IntegerType::get(1, &context);
|
IntegerType boolTy = IntegerType::get(&context, 1);
|
||||||
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
|
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
|
||||||
|
|
||||||
// Check that splat is automatically detected for boolean values.
|
// Check that splat is automatically detected for boolean values.
|
||||||
|
@ -81,7 +81,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
|
||||||
|
|
||||||
TEST(DenseSplatTest, BoolNonSplat) {
|
TEST(DenseSplatTest, BoolNonSplat) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
IntegerType boolTy = IntegerType::get(1, &context);
|
IntegerType boolTy = IntegerType::get(&context, 1);
|
||||||
RankedTensorType shape = RankedTensorType::get({6}, boolTy);
|
RankedTensorType shape = RankedTensorType::get({6}, boolTy);
|
||||||
|
|
||||||
// Check that we properly handle non-splat values.
|
// Check that we properly handle non-splat values.
|
||||||
|
@ -94,7 +94,7 @@ TEST(DenseSplatTest, OddIntSplat) {
|
||||||
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
|
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
constexpr size_t intWidth = 19;
|
constexpr size_t intWidth = 19;
|
||||||
IntegerType intTy = IntegerType::get(intWidth, &context);
|
IntegerType intTy = IntegerType::get(&context, intWidth);
|
||||||
APInt value(intWidth, 10);
|
APInt value(intWidth, 10);
|
||||||
|
|
||||||
testSplat(intTy, value);
|
testSplat(intTy, value);
|
||||||
|
@ -102,7 +102,7 @@ TEST(DenseSplatTest, OddIntSplat) {
|
||||||
|
|
||||||
TEST(DenseSplatTest, Int32Splat) {
|
TEST(DenseSplatTest, Int32Splat) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
IntegerType intTy = IntegerType::get(32, &context);
|
IntegerType intTy = IntegerType::get(&context, 32);
|
||||||
int value = 64;
|
int value = 64;
|
||||||
|
|
||||||
testSplat(intTy, value);
|
testSplat(intTy, value);
|
||||||
|
@ -110,7 +110,7 @@ TEST(DenseSplatTest, Int32Splat) {
|
||||||
|
|
||||||
TEST(DenseSplatTest, IntAttrSplat) {
|
TEST(DenseSplatTest, IntAttrSplat) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
IntegerType intTy = IntegerType::get(85, &context);
|
IntegerType intTy = IntegerType::get(&context, 85);
|
||||||
Attribute value = IntegerAttr::get(intTy, 109);
|
Attribute value = IntegerAttr::get(intTy, 109);
|
||||||
|
|
||||||
testSplat(intTy, value);
|
testSplat(intTy, value);
|
||||||
|
@ -151,7 +151,7 @@ TEST(DenseSplatTest, BF16Splat) {
|
||||||
TEST(DenseSplatTest, StringSplat) {
|
TEST(DenseSplatTest, StringSplat) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
Type stringType =
|
Type stringType =
|
||||||
OpaqueType::get(Identifier::get("test", &context), "string", &context);
|
OpaqueType::get(&context, Identifier::get("test", &context), "string");
|
||||||
StringRef value = "test-string";
|
StringRef value = "test-string";
|
||||||
testSplat(stringType, value);
|
testSplat(stringType, value);
|
||||||
}
|
}
|
||||||
|
@ -159,7 +159,7 @@ TEST(DenseSplatTest, StringSplat) {
|
||||||
TEST(DenseSplatTest, StringAttrSplat) {
|
TEST(DenseSplatTest, StringAttrSplat) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
Type stringType =
|
Type stringType =
|
||||||
OpaqueType::get(Identifier::get("test", &context), "string", &context);
|
OpaqueType::get(&context, Identifier::get("test", &context), "string");
|
||||||
Attribute stringAttr = StringAttr::get("test-string", stringType);
|
Attribute stringAttr = StringAttr::get("test-string", stringType);
|
||||||
testSplat(stringType, stringAttr);
|
testSplat(stringType, stringAttr);
|
||||||
}
|
}
|
||||||
|
@ -173,7 +173,7 @@ TEST(DenseComplexTest, ComplexFloatSplat) {
|
||||||
|
|
||||||
TEST(DenseComplexTest, ComplexIntSplat) {
|
TEST(DenseComplexTest, ComplexIntSplat) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
|
ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
|
||||||
std::complex<int64_t> value(10, 15);
|
std::complex<int64_t> value(10, 15);
|
||||||
testSplat(complexType, value);
|
testSplat(complexType, value);
|
||||||
}
|
}
|
||||||
|
@ -187,7 +187,7 @@ TEST(DenseComplexTest, ComplexAPFloatSplat) {
|
||||||
|
|
||||||
TEST(DenseComplexTest, ComplexAPIntSplat) {
|
TEST(DenseComplexTest, ComplexAPIntSplat) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
|
ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
|
||||||
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
|
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
|
||||||
testSplat(complexType, value);
|
testSplat(complexType, value);
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,7 @@ namespace mlir {
|
||||||
/// Helper that returns an example test::TestStruct for testing its
|
/// Helper that returns an example test::TestStruct for testing its
|
||||||
/// implementation.
|
/// implementation.
|
||||||
static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
|
static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
|
||||||
auto integerType = mlir::IntegerType::get(32, context);
|
auto integerType = mlir::IntegerType::get(context, 32);
|
||||||
auto integerAttr = mlir::IntegerAttr::get(integerType, 127);
|
auto integerAttr = mlir::IntegerAttr::get(integerType, 127);
|
||||||
|
|
||||||
auto floatType = mlir::FloatType::getF32(context);
|
auto floatType = mlir::FloatType::getF32(context);
|
||||||
|
@ -105,7 +105,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) {
|
||||||
expectedValues.begin(), expectedValues.end() - 1);
|
expectedValues.begin(), expectedValues.end() - 1);
|
||||||
|
|
||||||
// Add a copy of the last attribute with the wrong type.
|
// Add a copy of the last attribute with the wrong type.
|
||||||
auto i64Type = mlir::IntegerType::get(64, &context);
|
auto i64Type = mlir::IntegerType::get(&context, 64);
|
||||||
auto elementsType = mlir::RankedTensorType::get({3}, i64Type);
|
auto elementsType = mlir::RankedTensorType::get({3}, i64Type);
|
||||||
auto elementsAttr =
|
auto elementsAttr =
|
||||||
mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3});
|
mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3});
|
||||||
|
|
Loading…
Reference in New Issue