Revert "Reorder MLIRContext location in BuiltinAttributes.h"

This reverts commit 7827753f98.
This commit is contained in:
Tres Popp 2021-02-08 09:32:27 +01:00
parent 7827753f98
commit 511dd4f438
33 changed files with 140 additions and 139 deletions

View File

@ -267,27 +267,27 @@ class fir_AllocatableOp<string mnemonic, list<OpTrait> traits = []> :
static constexpr llvm::StringRef inType() { return "in_type"; }
static constexpr llvm::StringRef lenpName() { return "len_param_count"; }
mlir::Type getAllocatedType();
bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; }
unsigned numLenParams() {
if (auto val = (*this)->getAttrOfType<mlir::IntegerAttr>(lenpName()))
return val.getInt();
return 0;
}
operand_range getLenParams() {
return {operand_begin(), operand_begin() + numLenParams()};
}
unsigned numShapeOperands() {
return operand_end() - operand_begin() + numLenParams();
}
operand_range getShapeOperands() {
return {operand_begin() + numLenParams(), operand_end()};
}
static mlir::Type getRefTy(mlir::Type ty);
/// Get the input type of the allocation
@ -1131,7 +1131,7 @@ def fir_EmboxCharOp : fir_Op<"emboxchar", [NoSideEffect]> {
}];
let arguments = (ins AnyReferenceLike:$memref, AnyIntegerLike:$len);
let results = (outs fir_BoxCharType);
let assemblyFormat = [{
@ -1563,7 +1563,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
p.printFunctionalType((*this)->getOperandTypes(),
(*this)->getResultTypes());
}];
let verifier = [{
auto refTy = ref().getType();
if (fir::isa_ref_type(refTy)) {
@ -1598,7 +1598,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
OpBuilderDAG<(ins "Type":$type, "ValueRange":$operands,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
let extraClassDeclaration = [{
static constexpr llvm::StringRef baseType() { return "base_type"; }
mlir::Type getBaseType();
@ -1686,7 +1686,7 @@ def fir_FieldIndexOp : fir_OneResultOp<"field_index", [NoSideEffect]> {
let printer = [{
p << getOperationName() << ' '
<< (*this)->getAttrOfType<mlir::StringAttr>(fieldAttrName()).getValue()
<< (*this)->getAttrOfType<mlir::StringAttr>(fieldAttrName()).getValue()
<< ", " << (*this)->getAttr(typeAttrName());
if (getNumOperands()) {
p << '(';
@ -2007,7 +2007,7 @@ def fir_IterWhileOp : region_Op<"iterate_while",
CArg<"ValueRange", "llvm::None">:$iterArgs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>
];
let extraClassDeclaration = [{
mlir::Block *getBody() { return &region().front(); }
mlir::Value getIterateVar() { return getBody()->getArgument(1); }
@ -2276,11 +2276,11 @@ def fir_ConstfOp : fir_Op<"constf", [NoSideEffect]> {
}];
let arguments = (ins FirRealAttr:$constant);
let results = (outs fir_RealType:$res);
let assemblyFormat = "`(` $constant `)` attr-dict `:` type($res)";
let verifier = [{
if (!getType().isa<fir::RealType>())
return emitOpError("must be a !fir.real type");
@ -2357,7 +2357,7 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {
}];
let results = (outs fir_ComplexType);
let parser = [{
fir::RealAttr realp;
fir::RealAttr imagp;
@ -2455,7 +2455,7 @@ def fir_CmpcOp : fir_Op<"cmpc",
def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
let summary = "convert a symbol to an SSA value";
let description = [{
Convert a symbol (a function or global reference) to an SSA-value to be
used in other Operations.
@ -2474,7 +2474,7 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
def fir_ConvertOp : fir_OneResultOp<"convert", [NoSideEffect]> {
let summary = "encapsulates all Fortran scalar type conversions";
let description = [{
Generalized type conversion. Convert the ssa value from type T to type U.
Not all pairs of types have conversions. When types T and U are the same
@ -2705,7 +2705,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
mlir::Type resultType() {
return fir::AllocaOp::wrapResultType(getType());
}
/// Return the initializer attribute if it exists, or a null attribute.
Attribute getValueOrNull() { return initVal().getValueOr(Attribute()); }
@ -2728,9 +2728,9 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> {
}
mlir::FlatSymbolRefAttr getSymbol() {
return mlir::FlatSymbolRefAttr::get(getContext(),
return mlir::FlatSymbolRefAttr::get(
(*this)->getAttrOfType<mlir::StringAttr>(
mlir::SymbolTable::getSymbolAttrName()).getValue());
mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext());
}
}];
}
@ -2772,7 +2772,7 @@ def fir_GlobalLenOp : fir_Op<"global_len", []> {
}];
let printer = [{
p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName())
p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName())
<< ", " << (*this)->getAttr(intAttrName());
}];

View File

@ -173,7 +173,7 @@ mlir::Value Fortran::lower::FirOpBuilder::createConvert(mlir::Location loc,
fir::StringLitOp Fortran::lower::FirOpBuilder::createStringLit(
mlir::Location loc, mlir::Type eleTy, llvm::StringRef data) {
auto strAttr = mlir::StringAttr::get(getContext(), data);
auto strAttr = mlir::StringAttr::get(data, getContext());
auto valTag = mlir::Identifier::get(fir::StringLitOp::value(), getContext());
mlir::NamedAttribute dataAttr(valTag, strAttr);
auto sizeTag = mlir::Identifier::get(fir::StringLitOp::size(), getContext());

View File

@ -107,7 +107,7 @@ private:
ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get(context, "printf");
return SymbolRefAttr::get("printf", context);
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
@ -120,7 +120,7 @@ private:
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
return SymbolRefAttr::get(context, "printf");
return SymbolRefAttr::get("printf", context);
}
/// Return a value representing an access into a global string with the given

View File

@ -107,7 +107,7 @@ private:
ModuleOp module) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
return SymbolRefAttr::get(context, "printf");
return SymbolRefAttr::get("printf", context);
// Create a function declaration for printf, the signature is:
// * `i32 (i8*, ...)`
@ -120,7 +120,7 @@ private:
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType);
return SymbolRefAttr::get(context, "printf");
return SymbolRefAttr::get("printf", context);
}
/// Return a value representing an access into a global string with the given

View File

@ -31,7 +31,7 @@ inline bool isRowMajorMatmul(ArrayAttr indexingMaps) {
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context));
auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
return indexingMaps == maps;
}
@ -42,7 +42,7 @@ inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) {
auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context));
auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context));
auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context));
auto maps = ArrayAttr::get(context, {mapA, mapB, mapC});
auto maps = ArrayAttr::get({mapA, mapB, mapC}, context);
return indexingMaps == maps;
}

View File

@ -69,7 +69,7 @@ public:
using Base::Base;
using ValueType = ArrayRef<Attribute>;
static ArrayAttr get(MLIRContext *context, ArrayRef<Attribute> value);
static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
ArrayRef<Attribute> getValue() const;
Attribute operator[](unsigned idx) const;
@ -126,8 +126,8 @@ public:
/// attributes. This method assumes that the provided list is unordered. If
/// the caller can guarantee that the attributes are ordered by name,
/// getWithSorted should be used instead.
static DictionaryAttr get(MLIRContext *context,
ArrayRef<NamedAttribute> value);
static DictionaryAttr get(ArrayRef<NamedAttribute> value,
MLIRContext *context);
/// Construct a dictionary with an array of values that is known to already be
/// sorted by name and uniqued.
@ -250,7 +250,7 @@ public:
using Attribute::Attribute;
using ValueType = bool;
static BoolAttr get(MLIRContext *context, bool value);
static BoolAttr get(bool value, MLIRContext *context);
/// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to
/// avoid bringing in all of IntegerAttrs methods.
@ -292,8 +292,8 @@ public:
using Base::Base;
/// Get or create a new OpaqueAttr with the provided dialect and string data.
static OpaqueAttr get(MLIRContext *context, Identifier dialect,
StringRef attrData, Type type);
static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type,
MLIRContext *context);
/// Get or create a new OpaqueAttr with the provided dialect and string data.
/// If the given identifier is not a valid namespace for a dialect, then a
@ -325,7 +325,7 @@ public:
using ValueType = StringRef;
/// Get an instance of a StringAttr with the given string.
static StringAttr get(MLIRContext *context, StringRef bytes);
static StringAttr get(StringRef bytes, MLIRContext *context);
/// Get an instance of a StringAttr with the given string and Type.
static StringAttr get(StringRef bytes, Type type);
@ -348,12 +348,13 @@ public:
using Base::Base;
/// Construct a symbol reference for the given value name.
static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx);
/// Construct a symbol reference for the given value name, and a set of nested
/// references that are further resolve to a nested symbol.
static SymbolRefAttr get(MLIRContext *ctx, StringRef value,
ArrayRef<FlatSymbolRefAttr> references);
static SymbolRefAttr get(StringRef value,
ArrayRef<FlatSymbolRefAttr> references,
MLIRContext *ctx);
/// Returns the name of the top level symbol reference, i.e. the root of the
/// reference path.
@ -376,8 +377,8 @@ public:
using ValueType = StringRef;
/// Construct a symbol reference for the given value name.
static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) {
return SymbolRefAttr::get(ctx, value);
static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) {
return SymbolRefAttr::get(value, ctx);
}
/// Returns the name of the held symbol reference.

View File

@ -569,7 +569,7 @@ void FunctionLike<ConcreteType>::setArgAttrs(
if (attributes.empty())
return (void)static_cast<ConcreteType *>(this)->removeAttr(nameOut);
Operation *op = this->getOperation();
op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes));
op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext()));
}
template <typename ConcreteType>

View File

@ -315,7 +315,7 @@ public:
attrs = newAttrs;
}
void setAttrs(ArrayRef<NamedAttribute> newAttrs) {
setAttrs(DictionaryAttr::get(getContext(), newAttrs));
setAttrs(DictionaryAttr::get(newAttrs, getContext()));
}
/// Return the specified attribute if present, null otherwise.

View File

@ -44,7 +44,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
/*defaultImplementation=*/[{
this->getOperation()->setAttr(
mlir::SymbolTable::getSymbolAttrName(),
StringAttr::get(this->getOperation()->getContext(), name));
StringAttr::get(name, this->getOperation()->getContext()));
}]
>,
InterfaceMethod<"Gets the visibility of this symbol.",

View File

@ -42,9 +42,9 @@ bool mlirAttributeIsAArray(MlirAttribute attr) {
MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
MlirAttribute const *elements) {
SmallVector<Attribute, 8> attrs;
return wrap(
ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
elements, attrs)));
return wrap(ArrayAttr::get(
unwrapList(static_cast<size_t>(numElements), elements, attrs),
unwrap(ctx)));
}
intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
@ -71,7 +71,7 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
attributes.emplace_back(
Identifier::get(unwrap(elements[i].name), unwrap(ctx)),
unwrap(elements[i].attribute));
return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
return wrap(DictionaryAttr::get(attributes, unwrap(ctx)));
}
intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
@ -137,7 +137,7 @@ bool mlirAttributeIsABool(MlirAttribute attr) {
}
MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
return wrap(BoolAttr::get(unwrap(ctx), value));
return wrap(BoolAttr::get(value, unwrap(ctx)));
}
bool mlirBoolAttrGetValue(MlirAttribute attr) {
@ -163,9 +163,9 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) {
MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
intptr_t dataLength, const char *data,
MlirType type) {
return wrap(OpaqueAttr::get(
unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
StringRef(data, dataLength), unwrap(type)));
return wrap(
OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
StringRef(data, dataLength), unwrap(type), unwrap(ctx)));
}
MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
@ -185,7 +185,7 @@ bool mlirAttributeIsAString(MlirAttribute attr) {
}
MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
return wrap(StringAttr::get(unwrap(ctx), unwrap(str)));
return wrap(StringAttr::get(unwrap(str), unwrap(ctx)));
}
MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
@ -211,7 +211,7 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
refs.reserve(numReferences);
for (intptr_t i = 0; i < numReferences; ++i)
refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs));
return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx)));
}
MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
@ -241,7 +241,7 @@ bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
}
MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol)));
return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx)));
}
MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {

View File

@ -148,7 +148,7 @@ StringAttr GpuKernelToBlobPass::translateGPUModuleToBinaryAnnotation(
auto blob = convertModuleToBlob(llvmModule, loc, name);
if (!blob)
return {};
return StringAttr::get(loc->getContext(), {blob->data(), blob->size()});
return StringAttr::get({blob->data(), blob->size()}, loc->getContext());
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>

View File

@ -177,12 +177,12 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
// Set SPIR-V binary shader data as an attribute.
vulkanLaunchCallOp->setAttr(
kSPIRVBlobAttrName,
StringAttr::get(loc->getContext(), {binary.data(), binary.size()}));
StringAttr::get({binary.data(), binary.size()}, loc->getContext()));
// Set entry point name as an attribute.
vulkanLaunchCallOp->setAttr(
kSPIRVEntryPointAttrName,
StringAttr::get(loc->getContext(), launchOp.getKernelName()));
StringAttr::get(launchOp.getKernelName(), loc->getContext()));
launchOp.erase();
}

View File

@ -687,8 +687,8 @@ public:
rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
structValue = rewriter.create<LLVM::InsertValueOp>(
loc, structType, structValue, executionMode,
ArrayAttr::get(context,
{rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}));
ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)},
context));
// Insert extra operands if they exist into execution mode info struct.
for (unsigned i = 0, e = values.size(); i < e; ++i) {
@ -696,9 +696,9 @@ public:
Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
structValue = rewriter.create<LLVM::InsertValueOp>(
loc, structType, structValue, entry,
ArrayAttr::get(context,
{rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
rewriter.getIntegerAttr(rewriter.getI32Type(), i)}));
ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
rewriter.getIntegerAttr(rewriter.getI32Type(), i)},
context));
}
rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
rewriter.eraseOp(op);
@ -1297,17 +1297,17 @@ public:
switch (funcOp.function_control()) {
#define DISPATCH(functionControl, llvmAttr) \
case functionControl: \
newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \
newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \
break;
DISPATCH(spirv::FunctionControl::Inline,
StringAttr::get(context, "alwaysinline"));
StringAttr::get("alwaysinline", context));
DISPATCH(spirv::FunctionControl::DontInline,
StringAttr::get(context, "noinline"));
StringAttr::get("noinline", context));
DISPATCH(spirv::FunctionControl::Pure,
StringAttr::get(context, "readonly"));
StringAttr::get("readonly", context));
DISPATCH(spirv::FunctionControl::Const,
StringAttr::get(context, "readnone"));
StringAttr::get("readnone", context));
#undef DISPATCH

View File

@ -4016,7 +4016,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
if (failed(applyPartialConversion(m, target, std::move(patterns))))
signalPassFailure();
m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(),
StringAttr::get(m.getContext(), this->dataLayout));
StringAttr::get(this->dataLayout, m.getContext()));
}
};
} // end namespace

View File

@ -762,7 +762,7 @@ public:
if (positionAttrs.size() > 1) {
auto oneDVectorType = reducedVectorTypeBack(vectorType);
auto nMinusOnePositionAttrs =
ArrayAttr::get(context, positionAttrs.drop_back());
ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
@ -871,7 +871,7 @@ public:
if (positionAttrs.size() > 1) {
oneDVectorType = reducedVectorTypeBack(destVectorType);
auto nMinusOnePositionAttrs =
ArrayAttr::get(context, positionAttrs.drop_back());
ArrayAttr::get(positionAttrs.drop_back(), context);
extracted = rewriter.create<LLVM::ExtractValueOp>(
loc, typeConverter->convertType(oneDVectorType), extracted,
nMinusOnePositionAttrs);
@ -887,7 +887,7 @@ public:
// Potential insertion of resulting 1-D vector into array.
if (positionAttrs.size() > 1) {
auto nMinusOnePositionAttrs =
ArrayAttr::get(context, positionAttrs.drop_back());
ArrayAttr::get(positionAttrs.drop_back(), context);
inserted = rewriter.create<LLVM::InsertValueOp>(loc, llvmResultType,
adaptor.dest(), inserted,
nMinusOnePositionAttrs);

View File

@ -53,7 +53,7 @@ LogicalResult setMappingAttr(scf::ParallelOp ploopOp,
}
ArrayRef<Attribute> mappingAsAttrs(mapping.data(), mapping.size());
ploopOp->setAttr(getMappingAttrName(),
ArrayAttr::get(ploopOp.getContext(), mappingAsAttrs));
ArrayAttr::get(mappingAsAttrs, ploopOp.getContext()));
return success();
}
} // namespace gpu

View File

@ -225,7 +225,7 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
if (genericAttrNamesSet.count(attr.first.strref()) > 0)
genericAttrs.push_back(attr);
if (!genericAttrs.empty()) {
auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs);
auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext());
p << genericDictAttr;
}
@ -833,7 +833,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
// Handle the corner case of the result being a rank 0 shaped type. Return an
// emtpy ArrayAttr.
if (mapsConsumer.empty() && !mapsProducer.empty())
return ArrayAttr::get(context, ArrayRef<Attribute>());
return ArrayAttr::get(ArrayRef<Attribute>(), context);
if (mapsProducer.empty() || mapsConsumer.empty() ||
mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
mapsProducer.size() != mapsConsumer[0].getNumDims())
@ -854,7 +854,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
numLhsDims, /*numSymbols =*/0, reassociations, context)));
reassociations.clear();
}
return ArrayAttr::get(context, reassociationMaps);
return ArrayAttr::get(reassociationMaps, context);
}
namespace {

View File

@ -137,11 +137,11 @@ static ArrayAttr replaceUnitDims(DenseSet<unsigned> &unitDims,
// wrong, so abort.
if (!inversePermutation(concatAffineMaps(newIndexingMaps)))
return nullptr;
return ArrayAttr::get(context,
llvm::to_vector<4>(llvm::map_range(
newIndexingMaps, [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map);
})));
return ArrayAttr::get(
llvm::to_vector<4>(llvm::map_range(
newIndexingMaps,
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })),
context);
}
/// Modify the region of indexed generic op to drop arguments corresponding to
@ -220,7 +220,7 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
rewriter.startRootUpdate(op);
op.indexing_mapsAttr(newIndexingMapAttr);
op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes));
op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
(void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
rewriter.finalizeRootUpdate(op);
return success();
@ -282,7 +282,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
RankedTensorType::get(newShape, type.getElementType()),
AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
newIndexExprs, context),
ArrayAttr::get(context, reassociationMaps)};
ArrayAttr::get(reassociationMaps, context)};
return info;
}

View File

@ -77,9 +77,9 @@ LinalgOp mlir::linalg::interchange(LinalgOp op,
applyPermutationToVector(itTypesVector, interchangeVector);
op->setAttr(getIndexingMapsAttrName(),
ArrayAttr::get(context, newIndexingMaps));
ArrayAttr::get(newIndexingMaps, context));
op->setAttr(getIteratorTypesAttrName(),
ArrayAttr::get(context, itTypesVector));
ArrayAttr::get(itTypesVector, context));
return op;
}

View File

@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp,
});
for (auto &var : interfaceVarSet) {
interfaceVars.push_back(SymbolRefAttr::get(
funcOp.getContext(), cast<spirv::GlobalVariableOp>(var).sym_name()));
cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext()));
}
return success();
}

View File

@ -338,7 +338,7 @@ OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
return a;
}
// If this is reached, all inputs were statically known passing.
return BoolAttr::get(getContext(), true);
return BoolAttr::get(true, getContext());
}
static LogicalResult verify(AssumingAllOp op) {
@ -482,10 +482,10 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
// Both operands are not needed if one is a scalar.
if (operands[0] &&
operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
return BoolAttr::get(getContext(), true);
return BoolAttr::get(true, getContext());
if (operands[1] &&
operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
return BoolAttr::get(getContext(), true);
return BoolAttr::get(true, getContext());
if (operands[0] && operands[1]) {
auto lhsShape = llvm::to_vector<6>(
@ -494,7 +494,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
SmallVector<int64_t, 6> resultShape;
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
return BoolAttr::get(getContext(), true);
return BoolAttr::get(true, getContext());
}
// Lastly, see if folding can be completed based on what constraints are known
@ -506,7 +506,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
return BoolAttr::get(getContext(), true);
return BoolAttr::get(true, getContext());
// Because a failing witness result here represents an eventual assertion
// failure, we do not replace it with a constant witness.
@ -526,7 +526,7 @@ void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
OpFoldResult CstrEqOp::fold(ArrayRef<Attribute> operands) {
if (llvm::all_of(operands,
[&](Attribute a) { return a && a == operands[0]; }))
return BoolAttr::get(getContext(), true);
return BoolAttr::get(true, getContext());
// Because a failing witness result here represents an eventual assertion
// failure, we do not try to replace it with a constant witness. Similarly, we
@ -573,14 +573,14 @@ OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
if (lhs() == rhs())
return BoolAttr::get(getContext(), true);
return BoolAttr::get(true, getContext());
auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (lhs == nullptr)
return {};
auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
if (rhs == nullptr)
return {};
return BoolAttr::get(getContext(), lhs == rhs);
return BoolAttr::get(lhs == rhs, getContext());
}
//===----------------------------------------------------------------------===//

View File

@ -844,7 +844,7 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
if (lhs() == rhs()) {
auto val = applyCmpPredicateToEqualOperands(getPredicate());
return BoolAttr::get(getContext(), val);
return BoolAttr::get(val, getContext());
}
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
@ -853,7 +853,7 @@ OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
return BoolAttr::get(getContext(), val);
return BoolAttr::get(val, getContext());
}
//===----------------------------------------------------------------------===//

View File

@ -247,7 +247,7 @@ static void print(OpAsmPrinter &p, ContractionOp op) {
if (traitAttrsSet.count(attr.first.strref()) > 0)
attrs.push_back(attr);
auto dictAttr = DictionaryAttr::get(op.getContext(), attrs);
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", ";
p << op.rhs() << ", " << op.acc();
if (op.masks().size() == 2)
@ -1445,7 +1445,7 @@ static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
});
return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
return ArrayAttr::get(llvm::to_vector<8>(attrs), context);
}
static LogicalResult verify(InsertStridedSliceOp op) {

View File

@ -92,11 +92,11 @@ NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) {
UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
BoolAttr Builder::getBoolAttr(bool value) {
return BoolAttr::get(context, value);
return BoolAttr::get(value, context);
}
DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
return DictionaryAttr::get(context, value);
return DictionaryAttr::get(value, context);
}
IntegerAttr Builder::getIndexAttr(int64_t value) {
@ -200,11 +200,11 @@ FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
}
StringAttr Builder::getStringAttr(StringRef bytes) {
return StringAttr::get(context, bytes);
return StringAttr::get(bytes, context);
}
ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
return ArrayAttr::get(context, value);
return ArrayAttr::get(value, context);
}
FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
@ -214,12 +214,12 @@ FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
return getSymbolRefAttr(symName.getValue());
}
FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
return SymbolRefAttr::get(getContext(), value);
return SymbolRefAttr::get(value, getContext());
}
SymbolRefAttr
Builder::getSymbolRefAttr(StringRef value,
ArrayRef<FlatSymbolRefAttr> nestedReferences) {
return SymbolRefAttr::get(getContext(), value, nestedReferences);
return SymbolRefAttr::get(value, nestedReferences, getContext());
}
ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {

View File

@ -35,7 +35,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; }
// ArrayAttr
//===----------------------------------------------------------------------===//
ArrayAttr ArrayAttr::get(MLIRContext *context, ArrayRef<Attribute> value) {
ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
return Base::get(context, value);
}
@ -134,8 +134,8 @@ DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
return findDuplicateElement(array);
}
DictionaryAttr DictionaryAttr::get(MLIRContext *context,
ArrayRef<NamedAttribute> value) {
DictionaryAttr DictionaryAttr::get(ArrayRef<NamedAttribute> value,
MLIRContext *context) {
if (value.empty())
return DictionaryAttr::getEmpty(context);
assert(llvm::all_of(value,
@ -267,12 +267,13 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
// SymbolRefAttr
//===----------------------------------------------------------------------===//
FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
return Base::get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
}
SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
ArrayRef<FlatSymbolRefAttr> nestedReferences) {
SymbolRefAttr SymbolRefAttr::get(StringRef value,
ArrayRef<FlatSymbolRefAttr> nestedReferences,
MLIRContext *ctx) {
return Base::get(ctx, value, nestedReferences);
}
@ -293,7 +294,7 @@ ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
if (type.isSignlessInteger(1))
return BoolAttr::get(type.getContext(), value.getBoolValue());
return BoolAttr::get(value.getBoolValue(), type.getContext());
return Base::get(type.getContext(), type, value);
}
@ -376,8 +377,8 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
// OpaqueAttr
//===----------------------------------------------------------------------===//
OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
StringRef attrData, Type type) {
OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
MLIRContext *context) {
return Base::get(context, dialect, attrData, type);
}
@ -408,7 +409,7 @@ LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
// StringAttr
//===----------------------------------------------------------------------===//
StringAttr StringAttr::get(MLIRContext *context, StringRef bytes) {
StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
return get(bytes, NoneType::get(context));
}

View File

@ -166,7 +166,7 @@ void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
newAttrs.insert(attr);
for (auto &attr : getAttrs())
newAttrs.insert(attr);
dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector()));
dest->setAttrs(DictionaryAttr::get(newAttrs.takeVector(), getContext()));
// Clone the body.
getBody().cloneInto(&dest.getBody(), mapper);

View File

@ -872,7 +872,7 @@ void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
storage->setType(NoneType::get(ctx));
}
BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
}

View File

@ -76,7 +76,7 @@ Operation *Operation::create(Location location, OperationName name,
ArrayRef<NamedAttribute> attributes,
BlockRange successors, unsigned numRegions) {
return create(location, name, resultTypes, operands,
DictionaryAttr::get(location.getContext(), attributes),
DictionaryAttr::get(attributes, location.getContext()),
successors, numRegions);
}

View File

@ -46,7 +46,7 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
MLIRContext *ctx = symbol->getContext();
auto leafRef = FlatSymbolRefAttr::get(ctx, symbolName);
auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx);
results.push_back(leafRef);
// Early exit for when 'within' is the parent of 'symbol'.
@ -67,13 +67,13 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
getNameIfSymbol(symbolTableOp, symbolNameId);
if (!symbolTableName)
return failure();
results.push_back(SymbolRefAttr::get(ctx, *symbolTableName, nestedRefs));
results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx));
symbolTableOp = symbolTableOp->getParentOp();
if (symbolTableOp == within)
break;
nestedRefs.insert(nestedRefs.begin(),
FlatSymbolRefAttr::get(ctx, *symbolTableName));
FlatSymbolRefAttr::get(*symbolTableName, ctx));
} while (true);
return success();
}
@ -203,7 +203,7 @@ StringRef SymbolTable::getSymbolName(Operation *symbol) {
/// Sets the name of the given symbol operation.
void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
symbol->setAttr(getSymbolAttrName(),
StringAttr::get(symbol->getContext(), name));
StringAttr::get(name, symbol->getContext()));
}
/// Returns the visibility of the given symbol operation.
@ -235,7 +235,7 @@ void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
"unknown symbol visibility kind");
StringRef visName = vis == Visibility::Private ? "private" : "nested";
symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
}
/// Returns the nearest symbol table from a given operation `from`. Returns
@ -603,7 +603,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
// doesn't support parent references.
if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
symbol->getParentOp())
return {{SymbolRefAttr::get(symbol->getContext(), symName), limit}};
return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}};
return {};
}
@ -659,7 +659,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
template <typename IRUnit>
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringRef symbol,
IRUnit *limit) {
return {{SymbolRefAttr::get(limit->getContext(), symbol), limit}};
return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}};
}
/// Returns true if the given reference 'SubRef' is a sub reference of the
@ -825,11 +825,11 @@ static Attribute rebuildAttrAfterRAUW(
if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
updateAttrs(make_second_range(newAttrs));
return DictionaryAttr::get(dictAttr.getContext(), newAttrs);
return DictionaryAttr::get(newAttrs, dictAttr.getContext());
}
auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
updateAttrs(newAttrs);
return ArrayAttr::get(container.getContext(), newAttrs);
return ArrayAttr::get(newAttrs, container.getContext());
}
/// Generates a new symbol reference attribute with a new leaf reference.
@ -839,8 +839,8 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
return newLeafAttr;
auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
nestedRefs.back() = newLeafAttr;
return SymbolRefAttr::get(oldAttr.getContext(), oldAttr.getRootReference(),
nestedRefs);
return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs,
oldAttr.getContext());
}
/// The implementation of SymbolTable::replaceAllSymbolUses below.
@ -867,7 +867,7 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
// Generate a new attribute to replace the given attribute.
MLIRContext *ctx = limit->getContext();
FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(ctx, newSymbol);
FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx);
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
@ -883,13 +883,13 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
if (useRef != scope.symbol) {
if (scope.symbol.isa<FlatSymbolRefAttr>()) {
replacementRef =
SymbolRefAttr::get(ctx, newSymbol, useRef.getNestedReferences());
SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx);
} else {
auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
newLeafAttr;
replacementRef =
SymbolRefAttr::get(ctx, useRef.getRootReference(), nestedRefs);
SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx);
}
}

View File

@ -148,7 +148,7 @@ Attribute Parser::parseAttribute(Type type) {
return Attribute();
return type ? StringAttr::get(val, type)
: StringAttr::get(getContext(), val);
: StringAttr::get(val, getContext());
}
// Parse a symbol reference attribute.
@ -176,7 +176,7 @@ Attribute Parser::parseAttribute(Type type) {
std::string nameStr = getToken().getSymbolReference();
consumeToken(Token::at_identifier);
nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
}
return builder.getSymbolRefAttr(nameStr, nestedRefs);

View File

@ -742,8 +742,7 @@ void OpEmitter::genAttrGetters() {
body << " ::mlir::MLIRContext* ctx = getContext();\n";
body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
body << " return ::mlir::DictionaryAttr::get(";
body << " ctx, {\n";
body << " return ::mlir::DictionaryAttr::get({\n";
interleave(
derivedAttrs, body,
[&](const NamedAttribute &namedAttr) {
@ -756,7 +755,7 @@ void OpEmitter::genAttrGetters() {
<< "}";
},
",\n");
body << "});";
body << "\n }, ctx);";
}
}
}

View File

@ -150,7 +150,7 @@ static void emitFactoryDef(llvm::StringRef structName,
}
const char *getEndInfo = R"(
::mlir::Attribute dict = ::mlir::DictionaryAttr::get(context, fields);
::mlir::Attribute dict = ::mlir::DictionaryAttr::get(fields, context);
return dict.dyn_cast<{0}>();
}
)";

View File

@ -67,7 +67,7 @@ TEST(StructsGenTest, ClassofExtraFalse) {
newValues.push_back(wrongAttr);
// Make a new DictionaryAttr and validate.
auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
@ -88,7 +88,7 @@ TEST(StructsGenTest, ClassofBadNameFalse) {
auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
newValues.push_back(wrongAttr);
auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
@ -113,7 +113,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) {
auto wrongAttr = mlir::NamedAttribute(id, elementsAttr);
newValues.push_back(wrongAttr);
auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
@ -130,7 +130,7 @@ TEST(StructsGenTest, ClassofMissingFalse) {
expectedValues.begin() + 1, expectedValues.end());
// Make a new DictionaryAttr and validate it is not a validate TestStruct.
auto badDictionary = mlir::DictionaryAttr::get(&context, newValues);
auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}