[mlir][ods] Format: allow anchors in the else elements

This patch changes optional groups to allow anchors in the 'else'
element group. When printing, the optional condition is inverted to
decide which group to print. This is useful for parsing concrete
optional elements that don't have a `parseOptional*` method or some
other way to test whether it's present.

Depends on D133805

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D133812
This commit is contained in:
Jeff Niu 2022-09-13 15:08:39 -07:00
parent 52a479de60
commit 95a33b455d
11 changed files with 256 additions and 106 deletions

View File

@ -660,8 +660,9 @@ set to `llvm::None` and `Attribute` will be set to `nullptr`.
Only optional parameters or directives that only capture optional parameters can
be used in optional groups. An optional group is a set of elements optionally
printed based on the presence of an anchor. Suppose parameter `a` is an
`IntegerAttr`.
printed based on the presence of an anchor. The group in which the anchor is
placed is printed if it is present, otherwise the other one is printed. Suppose
parameter `a` is an `IntegerAttr`.
```
( `(` $a^ `)` ) : (`x`)?

View File

@ -856,17 +856,18 @@ of the assembly format can be marked as `optional` based on the presence of this
information. An optional group is defined as follows:
```
optional-group: `(` elements `)` (`:` `(` else-elements `)`)? `?`
optional-group: `(` then-elements `)` (`:` `(` else-elements `)`)? `?`
```
The `elements` of an optional group have the following requirements:
The elements of an optional group have the following requirements:
* The first element of the group must either be a attribute, literal, operand,
or region.
* The first element of `then-elements` must either be a attribute, literal,
operand, or region.
- This is because the first element must be optionally parsable.
* Exactly one argument variable or type directive within the group must be
marked as the anchor of the group.
- The anchor is the element whose presence controls whether the group
* Exactly one argument variable or type directive within either
`then-elements` or `else-elements` must be marked as the anchor of the
group.
- The anchor is the element whose presence controls which elements
should be printed/parsed.
- An element is marked as the anchor by adding a trailing `^`.
- The first element is *not* required to be the anchor of the group.

View File

@ -797,6 +797,11 @@ def OIListAllowedLiteral : TEST_Op<"oilist_allowed_literal"> {
}];
}
def ElseAnchorOp : TEST_Op<"else_anchor"> {
let arguments = (ins Optional<AnyType>:$a);
let assemblyFormat = "`(` (`?`) : (`` $a^ `:` type($a))? `)` attr-dict";
}
// This is used to test encoding of a string attribute into an SSA name of a
// pretty printed value name.
def StringAttrPrettyNameOp

View File

@ -332,4 +332,17 @@ def TestTypeCustomString : Test_Type<"TestTypeCustomString"> {
custom<BarString>(ref($foo)) `>` }];
}
def TestTypeElseAnchor : Test_Type<"TestTypeElseAnchor"> {
let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a);
let mnemonic = "else_anchor";
let assemblyFormat = "`<` (`?`) : ($a^)? `>`";
}
def TestTypeElseAnchorStruct : Test_Type<"TestTypeElseAnchorStruct"> {
let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
OptionalParameter<"mlir::Optional<int>">:$b);
let mnemonic = "else_anchor_struct";
let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
}
#endif // TEST_TYPEDEFS

View File

@ -571,3 +571,39 @@ def TypeL : TestType<"TestN"> {
let mnemonic = "type_l";
let assemblyFormat = [{ custom<Foo>($a, "1") }];
}
// TYPE-LABEL: ::mlir::Type TestOType::parse
// TYPE: if (odsParser.parseOptionalQuestion())
// TYPE: _result_a =
// TYPE: else
// TYPE-LABEL: void TestOType::print
// TYPE: if (!((getA())))
// TYPE: odsPrinter << ' ' << "?"
// TYPE: else
// TYPE: odsPrinter.printStrippedAttrOrType(getA())
def TypeM : TestType<"TestO"> {
let parameters = (ins OptionalParameter<"int">:$a);
let mnemonic = "type_m";
let assemblyFormat = "(`?`) : ($a^)?";
}
// TYPE-LABEL: ::mlir::Type TestPType::parse
// TYPE: if (odsParser.parseOptionalQuestion())
// TYPE: bool _seen_a
// TYPE: bool _seen_b
// TYPE: _loop_body(_paramKey))
// TYPE: else {
// TYPE-NEXT: }
// TYPE-LABEL: void TestPType::print
// TYPE: if (!((getA()) || (getB())))
// TYPE-NEXT: odsPrinter << "?"
def TypeN : TestType<"TestP"> {
let parameters = (ins OptionalParameter<"int">:$a,
OptionalParameter<"int">:$b);
let mnemonic = "type_n";
let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
}

View File

@ -463,3 +463,18 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
// CHECK: test.has_str_value
test.has_str_value {}
//===----------------------------------------------------------------------===//
// ElseAnchorOp
//===----------------------------------------------------------------------===//
// CHECK-LABEL: @else_anchor_op
func.func @else_anchor_op(%a: !test.else_anchor<?>, %b: !test.else_anchor<5>) {
// CHECK: test.else_anchor(?) {a = !test.else_anchor_struct<?>}
test.else_anchor(?) {a = !test.else_anchor_struct<?>}
// CHECK: test.else_anchor(%{{.*}} : !test.else_anchor<?>) {a = !test.else_anchor_struct<a = 0>}
test.else_anchor(%a : !test.else_anchor<?>) {a = !test.else_anchor_struct<a = 0>}
// CHECK: test.else_anchor(%{{.*}} : !test.else_anchor<5>) {a = !test.else_anchor_struct<b = 0>}
test.else_anchor(%b : !test.else_anchor<5>) {a = !test.else_anchor_struct<b = 0>}
return
}

View File

@ -40,3 +40,34 @@ def CustomStringLiteralB : TestFormat_Op<[{
def CustomStringLiteralC : TestFormat_Op<[{
custom<Foo>("$_builder.getStringAttr(\"foo\")") attr-dict
}]>;
//===----------------------------------------------------------------------===//
// Optional Groups
//===----------------------------------------------------------------------===//
// CHECK-LABEL: OptionalGroupA::parse
// CHECK: if (::mlir::succeeded(parser.parseOptionalQuestion())
// CHECK-NEXT: else
// CHECK: parser.parseOptionalOperand
// CHECK-LABEL: OptionalGroupA::print
// CHECK: if (!getA())
// CHECK-NEXT: odsPrinter << ' ' << "?";
// CHECK-NEXT: else
// CHECK: odsPrinter << value;
def OptionalGroupA : TestFormat_Op<[{
(`?`) : ($a^)? attr-dict
}]>, Arguments<(ins Optional<I1>:$a)>;
// CHECK-LABEL: OptionalGroupB::parse
// CHECK: if (::mlir::succeeded(parser.parseOptionalKeyword("foo")))
// CHECK-NEXT: else
// CHECK-NEXT: result.addAttribute("a", parser.getBuilder().getUnitAttr())
// CHECK: parser.parseKeyword("bar")
// CHECK-LABEL: OptionalGroupB::print
// CHECK: if (!(*this)->getAttr("a"))
// CHECK-NEXT: odsPrinter << ' ' << "foo"
// CHECK-NEXT: else
// CHECK-NEXT: odsPrinter << ' ' << "bar"
def OptionalGroupB : TestFormat_Op<[{
(`foo`) : (`bar` $a^)? attr-dict
}]>, Arguments<(ins UnitAttr:$a)>;

View File

@ -656,10 +656,10 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
MethodBody &os) {
ArrayRef<FormatElement *> elements =
el->getThenElements().drop_front(el->getParseStart());
ArrayRef<FormatElement *> thenElements =
el->getThenElements(/*parseable=*/true);
FormatElement *first = elements.front();
FormatElement *first = thenElements.front();
const auto guardOn = [&](auto params) {
os << "if (!(";
llvm::interleave(
@ -687,12 +687,12 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
}
os.indent();
// Generate the parsers for the rest of the elements.
for (FormatElement *element : el->getElseElements())
// Generate the parsers for the rest of the thenElements.
for (FormatElement *element : el->getElseElements(/*parseable=*/true))
genElementParser(element, ctx, os);
os.unindent() << "} else {\n";
os.indent();
for (FormatElement *element : elements.drop_front())
for (FormatElement *element : thenElements.drop_front())
genElementParser(element, ctx, os);
os.unindent() << "}\n";
}
@ -781,12 +781,16 @@ void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
/// Generate code to guard printing on the presence of any optional parameters.
template <typename ParameterRange>
static void guardOnAny(FmtContext &ctx, MethodBody &os,
ParameterRange &&params) {
static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
bool inverted = false) {
os << "if (";
if (inverted)
os << "!(";
llvm::interleave(
params, os,
[&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || ");
if (inverted)
os << ")";
os << ") {\n";
os.indent();
}
@ -860,12 +864,12 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
MethodBody &os) {
FormatElement *anchor = el->getAnchor();
if (auto *param = dyn_cast<ParameterElement>(anchor)) {
guardOnAny(ctx, os, llvm::makeArrayRef(param));
guardOnAny(ctx, os, llvm::makeArrayRef(param), el->isInverted());
} else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
guardOnAny(ctx, os, params->getParams());
guardOnAny(ctx, os, params->getParams(), el->isInverted());
} else {
auto *strct = cast<StructDirective>(anchor);
guardOnAny(ctx, os, strct->getParams());
guardOnAny(ctx, os, strct->getParams(), el->isInverted());
}
// Generate the printer for the contained elements.
{

View File

@ -321,35 +321,42 @@ FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
// Parse the child elements for this optional group.
std::vector<FormatElement *> thenElements, elseElements;
FormatElement *anchor = nullptr;
do {
FailureOr<FormatElement *> element = parseElement(TopLevelContext);
if (failed(element))
return failure();
// Check for an anchor.
if (curToken.is(FormatToken::caret)) {
if (anchor)
return emitError(curToken.getLoc(), "only one element can be marked as "
"the anchor of an optional group");
anchor = *element;
consumeToken();
}
thenElements.push_back(*element);
} while (!curToken.is(FormatToken::r_paren));
consumeToken();
// Parse the `else` elements of this optional group.
if (curToken.is(FormatToken::colon)) {
consumeToken();
if (failed(
parseToken(FormatToken::l_paren,
"expected '(' to start else branch of optional group")))
return failure();
auto parseChildElements =
[this, &anchor](std::vector<FormatElement *> &elements) -> LogicalResult {
do {
FailureOr<FormatElement *> element = parseElement(TopLevelContext);
if (failed(element))
return failure();
elseElements.push_back(*element);
// Check for an anchor.
if (curToken.is(FormatToken::caret)) {
if (anchor) {
return emitError(curToken.getLoc(),
"only one element can be marked as the anchor of an "
"optional group");
}
anchor = *element;
consumeToken();
}
elements.push_back(*element);
} while (!curToken.is(FormatToken::r_paren));
return success();
};
// Parse the 'then' elements. If the anchor was found in this group, then the
// optional is not inverted.
if (failed(parseChildElements(thenElements)))
return failure();
consumeToken();
bool inverted = !anchor;
// Parse the `else` elements of this optional group.
if (curToken.is(FormatToken::colon)) {
consumeToken();
if (failed(parseToken(
FormatToken::l_paren,
"expected '(' to start else branch of optional group")) ||
failed(parseChildElements(elseElements)))
return failure();
consumeToken();
}
if (failed(parseToken(FormatToken::question,
@ -367,17 +374,21 @@ FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
// Get the first parsable element. It must be an element that can be
// optionally-parsed.
auto parseBegin = llvm::find_if_not(thenElements, [](FormatElement *element) {
auto isWhitespace = [](FormatElement *element) {
return isa<WhitespaceElement>(element);
});
if (!isa<LiteralElement, VariableElement>(*parseBegin)) {
};
auto thenParseBegin = llvm::find_if_not(thenElements, isWhitespace);
auto elseParseBegin = llvm::find_if_not(elseElements, isWhitespace);
unsigned thenParseStart = std::distance(thenElements.begin(), thenParseBegin);
unsigned elseParseStart = std::distance(elseElements.begin(), elseParseBegin);
if (!isa<LiteralElement, VariableElement>(*thenParseBegin)) {
return emitError(loc, "first parsable element of an optional group must be "
"a literal or variable");
}
unsigned parseStart = std::distance(thenElements.begin(), parseBegin);
return create<OptionalElement>(std::move(thenElements),
std::move(elseElements), anchor, parseStart);
std::move(elseElements), thenParseStart,
elseParseStart, anchor, inverted);
}
FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,

View File

@ -378,33 +378,48 @@ public:
/// Create an optional group with the given child elements.
OptionalElement(std::vector<FormatElement *> &&thenElements,
std::vector<FormatElement *> &&elseElements,
FormatElement *anchor, unsigned parseStart)
unsigned thenParseStart, unsigned elseParseStart,
FormatElement *anchor, bool inverted)
: thenElements(std::move(thenElements)),
elseElements(std::move(elseElements)), anchor(anchor),
parseStart(parseStart) {}
elseElements(std::move(elseElements)), thenParseStart(thenParseStart),
elseParseStart(elseParseStart), anchor(anchor), inverted(inverted) {}
/// Return the `then` elements of the optional group.
ArrayRef<FormatElement *> getThenElements() const { return thenElements; }
/// Return the `then` elements of the optional group. Drops the first
/// `thenParseStart` whitespace elements if `parseable` is true.
ArrayRef<FormatElement *> getThenElements(bool parseable = false) const {
return llvm::makeArrayRef(thenElements)
.drop_front(parseable ? thenParseStart : 0);
}
/// Return the `else` elements of the optional group.
ArrayRef<FormatElement *> getElseElements() const { return elseElements; }
/// Return the `else` elements of the optional group. Drops the first
/// `elseParseStart` whitespace elements if `parseable` is true.
ArrayRef<FormatElement *> getElseElements(bool parseable = false) const {
return llvm::makeArrayRef(elseElements)
.drop_front(parseable ? elseParseStart : 0);
}
/// Return the anchor of the optional group.
FormatElement *getAnchor() const { return anchor; }
/// Return the index of the first element to be parsed.
unsigned getParseStart() const { return parseStart; }
/// Return true if the optional group is inverted.
bool isInverted() const { return inverted; }
private:
/// The child elements emitted when the anchor is present.
std::vector<FormatElement *> thenElements;
/// The child elements emitted when the anchor is not present.
std::vector<FormatElement *> elseElements;
/// The anchor element of the optional group.
FormatElement *anchor;
/// The index of the first element that is parsed in `thenElements`. That is,
/// the first non-whitespace element.
unsigned parseStart;
unsigned thenParseStart;
/// The index of the first element that is parsed in `elseElements`. That is,
/// the first non-whitespace element.
unsigned elseParseStart;
/// The anchor element of the optional group.
FormatElement *anchor;
/// Whether the optional group condition is inverted and the anchor element is
/// in the else group.
bool inverted;
};
//===----------------------------------------------------------------------===//

View File

@ -1119,17 +1119,43 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
GenContext genCtx) {
/// Optional Group.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
ArrayRef<FormatElement *> elements =
optional->getThenElements().drop_front(optional->getParseStart());
auto genElementParsers = [&](FormatElement *firstElement,
ArrayRef<FormatElement *> elements,
bool thenGroup) {
// If the anchor is a unit attribute, we don't need to print it. When
// parsing, we will add this attribute if this group is present.
FormatElement *elidedAnchorElement = nullptr;
auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
if (anchorAttr && anchorAttr != firstElement &&
anchorAttr->isUnitAttr()) {
elidedAnchorElement = anchorAttr;
if (!thenGroup == optional->isInverted()) {
// Add the anchor unit attribute to the operation state.
body << " result.addAttribute(\"" << anchorAttr->getVar()->name
<< "\", parser.getBuilder().getUnitAttr());\n";
}
}
// Generate the rest of the elements inside an optional group. Elements in
// an optional group after the guard are parsed as required.
for (FormatElement *childElement : elements)
if (childElement != elidedAnchorElement)
genElementParser(childElement, body, attrTypeCtx,
GenContext::Optional);
};
ArrayRef<FormatElement *> thenElements =
optional->getThenElements(/*parseable=*/true);
// Generate a special optional parser for the first element to gate the
// parsing of the rest of the elements.
FormatElement *firstElement = elements.front();
FormatElement *firstElement = thenElements.front();
if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
genElementParser(attrVar, body, attrTypeCtx);
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
} else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
body << " if (succeeded(parser.parseOptional";
body << " if (::mlir::succeeded(parser.parseOptional";
genLiteralParser(literal->getSpelling(), body);
body << ")) {\n";
} else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
@ -1151,31 +1177,18 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
}
}
// If the anchor is a unit attribute, we don't need to print it. When
// parsing, we will add this attribute if this group is present.
FormatElement *elidedAnchorElement = nullptr;
auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) {
elidedAnchorElement = anchorAttr;
// Add the anchor unit attribute to the operation state.
body << " result.addAttribute(\"" << anchorAttr->getVar()->name
<< "\", parser.getBuilder().getUnitAttr());\n";
}
// Generate the rest of the elements inside an optional group. Elements in
// an optional group after the guard are parsed as required.
for (FormatElement *childElement : llvm::drop_begin(elements, 1))
if (childElement != elidedAnchorElement)
genElementParser(childElement, body, attrTypeCtx, GenContext::Optional);
genElementParsers(firstElement, thenElements.drop_front(),
/*thenGroup=*/true);
body << " }";
// Generate the else elements.
auto elseElements = optional->getElseElements();
if (!elseElements.empty()) {
body << " else {\n";
for (FormatElement *childElement : elseElements)
genElementParser(childElement, body, attrTypeCtx);
ArrayRef<FormatElement *> elseElements =
optional->getElseElements(/*parsable=*/true);
genElementParsers(elseElements.front(), elseElements,
/*thenGroup=*/false);
body << " }";
}
body << "\n";
@ -1842,15 +1855,15 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
const NamedTypeConstraint *var = element->getVar();
std::string name = op.getGetterName(var->name);
if (var->isOptional())
body << " if (" << name << "()) {\n";
body << name << "()";
else if (var->isVariadic())
body << " if (!" << name << "().empty()) {\n";
body << "!" << name << "().empty()";
})
.Case<RegionVariable>([&](RegionVariable *element) {
const NamedRegion *var = element->getVar();
std::string name = op.getGetterName(var->name);
// TODO: Add a check for optional regions here when ODS supports it.
body << " if (!" << name << "().empty()) {\n";
body << "!" << name << "().empty()";
})
.Case<TypeDirective>([&](TypeDirective *element) {
genOptionalGroupPrinterAnchor(element->getArg(), op, body);
@ -1859,8 +1872,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
})
.Case<AttributeVariable>([&](AttributeVariable *attr) {
body << " if ((*this)->getAttr(\"" << attr->getVar()->name
<< "\")) {\n";
body << "(*this)->getAttr(\"" << attr->getVar()->name << "\")";
});
}
@ -1912,39 +1924,45 @@ void OperationFormat::genElementPrinter(FormatElement *element,
if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
// Emit the check for the presence of the anchor element.
FormatElement *anchor = optional->getAnchor();
body << " if (";
if (optional->isInverted())
body << "!";
genOptionalGroupPrinterAnchor(anchor, op, body);
body << ") {\n";
body.indent();
// If the anchor is a unit attribute, we don't need to print it. When
// parsing, we will add this attribute if this group is present.
auto elements = optional->getThenElements();
ArrayRef<FormatElement *> thenElements = optional->getThenElements();
ArrayRef<FormatElement *> elseElements = optional->getElseElements();
FormatElement *elidedAnchorElement = nullptr;
auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
if (anchorAttr && anchorAttr != elements.front() &&
if (anchorAttr && anchorAttr != thenElements.front() &&
(elseElements.empty() || anchorAttr != elseElements.front()) &&
anchorAttr->isUnitAttr()) {
elidedAnchorElement = anchorAttr;
}
auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) {
for (FormatElement *childElement : elements) {
if (childElement != elidedAnchorElement) {
genElementPrinter(childElement, body, op, shouldEmitSpace,
lastWasPunctuation);
}
}
};
// Emit each of the elements.
for (FormatElement *childElement : elements) {
if (childElement != elidedAnchorElement) {
genElementPrinter(childElement, body, op, shouldEmitSpace,
lastWasPunctuation);
}
}
body << " }";
genElementPrinters(thenElements);
body << "}";
// Emit each of the else elements.
auto elseElements = optional->getElseElements();
if (!elseElements.empty()) {
body << " else {\n";
for (FormatElement *childElement : elseElements) {
genElementPrinter(childElement, body, op, shouldEmitSpace,
lastWasPunctuation);
}
body << " }";
genElementPrinters(elseElements);
body << "}";
}
body << "\n";
body.unindent() << "\n";
return;
}