Move getAsmBlockArgumentNames from OpAsmDialectInterface to OpAsmOpInterface

This method is more suitable as an opinterface: it seems intrinsic to
individual instances of the operation instead of the dialect.
Also remove the restriction on the interface being applicable to the entry block only.

Differential Revision: https://reviews.llvm.org/D116018
This commit is contained in:
Mehdi Amini 2021-12-20 07:17:26 +00:00
parent 9c11e95286
commit 7f9e9c7fc3
5 changed files with 45 additions and 37 deletions

View File

@ -49,7 +49,18 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
}],
"void", "getAsmResultNames",
(ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
"", ";"
"", "return;"
>,
InterfaceMethod<[{
Get a special name to use when printing the block arguments for a region
immediately nested under this operation.
}],
"void", "getAsmBlockArgumentNames",
(ins
"::mlir::Region&":$region,
"::mlir::OpAsmSetValueNameFn":$setNameFn
),
"", "return;"
>,
StaticInterfaceMethod<[{
Return the default dialect used when printing/parsing operations in

View File

@ -1348,11 +1348,6 @@ public:
/// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
virtual void getAsmResultNames(Operation *op,
OpAsmSetValueNameFn setNameFn) const {}
/// Get a special name to use when printing the entry block arguments of the
/// region contained by an operation in this dialect.
virtual void getAsmBlockArgumentNames(Block *block,
OpAsmSetValueNameFn setNameFn) const {}
};
} // namespace mlir

View File

@ -1006,6 +1006,20 @@ void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
}
void SSANameState::numberValuesInRegion(Region &region) {
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(arg.cast<BlockArgument>().getOwner()->getParent() == &region &&
"arg not defined in current region");
setValueName(arg, name);
};
if (!printerFlags.shouldPrintGenericOpForm()) {
if (Operation *op = region.getParentOp()) {
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
}
}
// Number the values within this region in a breadth-first order.
unsigned nextBlockID = 0;
for (auto &block : region) {
@ -1017,23 +1031,9 @@ void SSANameState::numberValuesInRegion(Region &region) {
}
void SSANameState::numberValuesInBlock(Block &block) {
auto setArgNameFn = [&](Value arg, StringRef name) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(arg.cast<BlockArgument>().getOwner() == &block &&
"arg not defined in 'block'");
setValueName(arg, name);
};
bool isEntryBlock = block.isEntryBlock();
if (isEntryBlock && !printerFlags.shouldPrintGenericOpForm()) {
if (auto *op = block.getParentOp()) {
if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
}
}
// Number the block arguments. We give entry block arguments a special name
// 'arg'.
bool isEntryBlock = block.isEntryBlock();
SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
llvm::raw_svector_ostream specialName(specialNameBuffer);
for (auto arg : block.getArguments()) {

View File

@ -105,20 +105,6 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
setNameFn(asmOp, "result");
}
void getAsmBlockArgumentNames(Block *block,
OpAsmSetValueNameFn setNameFn) const final {
auto op = block->getParentOp();
auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
if (!arrayAttr)
return;
auto args = block->getArguments();
auto e = std::min(arrayAttr.size(), args.size());
for (unsigned i = 0; i < e; ++i) {
if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
setNameFn(args[i], strAttr.getValue());
}
}
};
struct TestDialectFoldInterface : public DialectFoldInterface {
@ -848,6 +834,19 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
return parser.parseRegion(*body, ivsInfo, argTypes);
}
void PolyForOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
if (!arrayAttr)
return;
auto args = getRegion().front().getArguments();
auto e = std::min(arrayAttr.size(), args.size());
for (unsigned i = 0; i < e; ++i) {
if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
setNameFn(args[i], strAttr.getValue());
}
}
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//

View File

@ -1667,13 +1667,16 @@ def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
let printer = [{ return ::print(p, *this); }];
}
def PolyForOp : TEST_Op<"polyfor">
def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]>
{
let summary = "polyfor operation";
let description = [{
Test op with multiple region arguments, each argument of index type.
}];
let extraClassDeclaration = [{
void getAsmBlockArgumentNames(mlir::Region &region,
mlir::OpAsmSetValueNameFn setNameFn);
}];
let regions = (region SizedRegion<1>:$region);
let parser = [{ return ::parse$cppClass(parser, result); }];
}