Move getAsmBlockArgumentNames from OpAsmDialectInterface to OpAsmOpInterface
This method is more suitable as an opinterface: it seems intrinsic to individual instances of the operation instead of the dialect. Also remove the restriction on the interface being applicable to the entry block only. Differential Revision: https://reviews.llvm.org/D116018
This commit is contained in:
parent
9c11e95286
commit
7f9e9c7fc3
|
@ -49,7 +49,18 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
|
|||
}],
|
||||
"void", "getAsmResultNames",
|
||||
(ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
|
||||
"", ";"
|
||||
"", "return;"
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Get a special name to use when printing the block arguments for a region
|
||||
immediately nested under this operation.
|
||||
}],
|
||||
"void", "getAsmBlockArgumentNames",
|
||||
(ins
|
||||
"::mlir::Region&":$region,
|
||||
"::mlir::OpAsmSetValueNameFn":$setNameFn
|
||||
),
|
||||
"", "return;"
|
||||
>,
|
||||
StaticInterfaceMethod<[{
|
||||
Return the default dialect used when printing/parsing operations in
|
||||
|
|
|
@ -1348,11 +1348,6 @@ public:
|
|||
/// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
|
||||
virtual void getAsmResultNames(Operation *op,
|
||||
OpAsmSetValueNameFn setNameFn) const {}
|
||||
|
||||
/// Get a special name to use when printing the entry block arguments of the
|
||||
/// region contained by an operation in this dialect.
|
||||
virtual void getAsmBlockArgumentNames(Block *block,
|
||||
OpAsmSetValueNameFn setNameFn) const {}
|
||||
};
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -1006,6 +1006,20 @@ void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) {
|
|||
}
|
||||
|
||||
void SSANameState::numberValuesInRegion(Region ®ion) {
|
||||
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
|
||||
assert(!valueIDs.count(arg) && "arg numbered multiple times");
|
||||
assert(arg.cast<BlockArgument>().getOwner()->getParent() == ®ion &&
|
||||
"arg not defined in current region");
|
||||
setValueName(arg, name);
|
||||
};
|
||||
|
||||
if (!printerFlags.shouldPrintGenericOpForm()) {
|
||||
if (Operation *op = region.getParentOp()) {
|
||||
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
|
||||
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
|
||||
}
|
||||
}
|
||||
|
||||
// Number the values within this region in a breadth-first order.
|
||||
unsigned nextBlockID = 0;
|
||||
for (auto &block : region) {
|
||||
|
@ -1017,23 +1031,9 @@ void SSANameState::numberValuesInRegion(Region ®ion) {
|
|||
}
|
||||
|
||||
void SSANameState::numberValuesInBlock(Block &block) {
|
||||
auto setArgNameFn = [&](Value arg, StringRef name) {
|
||||
assert(!valueIDs.count(arg) && "arg numbered multiple times");
|
||||
assert(arg.cast<BlockArgument>().getOwner() == &block &&
|
||||
"arg not defined in 'block'");
|
||||
setValueName(arg, name);
|
||||
};
|
||||
|
||||
bool isEntryBlock = block.isEntryBlock();
|
||||
if (isEntryBlock && !printerFlags.shouldPrintGenericOpForm()) {
|
||||
if (auto *op = block.getParentOp()) {
|
||||
if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
|
||||
asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
|
||||
}
|
||||
}
|
||||
|
||||
// Number the block arguments. We give entry block arguments a special name
|
||||
// 'arg'.
|
||||
bool isEntryBlock = block.isEntryBlock();
|
||||
SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
|
||||
llvm::raw_svector_ostream specialName(specialNameBuffer);
|
||||
for (auto arg : block.getArguments()) {
|
||||
|
|
|
@ -105,20 +105,6 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
|
|||
if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
|
||||
setNameFn(asmOp, "result");
|
||||
}
|
||||
|
||||
void getAsmBlockArgumentNames(Block *block,
|
||||
OpAsmSetValueNameFn setNameFn) const final {
|
||||
auto op = block->getParentOp();
|
||||
auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
|
||||
if (!arrayAttr)
|
||||
return;
|
||||
auto args = block->getArguments();
|
||||
auto e = std::min(arrayAttr.size(), args.size());
|
||||
for (unsigned i = 0; i < e; ++i) {
|
||||
if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
|
||||
setNameFn(args[i], strAttr.getValue());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct TestDialectFoldInterface : public DialectFoldInterface {
|
||||
|
@ -848,6 +834,19 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
|
|||
return parser.parseRegion(*body, ivsInfo, argTypes);
|
||||
}
|
||||
|
||||
void PolyForOp::getAsmBlockArgumentNames(Region ®ion,
|
||||
OpAsmSetValueNameFn setNameFn) {
|
||||
auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
|
||||
if (!arrayAttr)
|
||||
return;
|
||||
auto args = getRegion().front().getArguments();
|
||||
auto e = std::min(arrayAttr.size(), args.size());
|
||||
for (unsigned i = 0; i < e; ++i) {
|
||||
if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
|
||||
setNameFn(args[i], strAttr.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test removing op with inner ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1667,13 +1667,16 @@ def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
|
|||
let printer = [{ return ::print(p, *this); }];
|
||||
}
|
||||
|
||||
def PolyForOp : TEST_Op<"polyfor">
|
||||
def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]>
|
||||
{
|
||||
let summary = "polyfor operation";
|
||||
let description = [{
|
||||
Test op with multiple region arguments, each argument of index type.
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
void getAsmBlockArgumentNames(mlir::Region ®ion,
|
||||
mlir::OpAsmSetValueNameFn setNameFn);
|
||||
}];
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue