[mlir] Printing oilist element

This patch attempts to deduce when the oilist element must be printed
based on the optional arguments to it. This especially helps creating
an operation accurately because with the current implementation, the
inferred unit attributes must be manually added to print the clauses
appropriately.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D121579
This commit is contained in:
Shraiysh Vaishay 2022-03-22 10:08:33 +05:30
parent f863df9a05
commit ddc90da478
4 changed files with 140 additions and 65 deletions

View File

@ -199,7 +199,7 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
$allocate_vars, type($allocate_vars),
$allocators_vars, type($allocators_vars)
) `)`
| `nowait`
| `nowait` $nowait
) $region attr-dict
}];
@ -438,7 +438,7 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> {
oilist( `if` `(` $if_expr `)`
| `device` `(` $device `:` type($device) `)`
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
| `nowait`
| `nowait` $nowait
) $region attr-dict
}];
}

View File

@ -498,6 +498,10 @@ func @succeededOilistTrivial() {
test.oilist_with_keywords_only keyword otherKeyword
// CHECK: test.oilist_with_keywords_only keyword otherKeyword
test.oilist_with_keywords_only otherKeyword keyword
// CHECK: test.oilist_with_keywords_only thirdKeyword
test.oilist_with_keywords_only thirdKeyword
// CHECK: test.oilist_with_keywords_only keyword thirdKeyword
test.oilist_with_keywords_only keyword thirdKeyword
return
}
@ -550,7 +554,7 @@ func @succeededOilistCustom(%arg0: i32, %arg1: i32, %arg2: i32) {
test.oilist_custom private (%arg0, %arg1 : i32, i32)
// CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) nowait
test.oilist_custom private (%arg0, %arg1 : i32, i32) nowait
// CHECK: test.oilist_custom private(%arg0, %arg1 : i32, i32) nowait reduction (%arg1)
// CHECK: test.oilist_custom private(%arg0, %arg1 : i32, i32) reduction (%arg1) nowait
test.oilist_custom nowait reduction (%arg1) private (%arg0, %arg1 : i32, i32)
return
}

View File

@ -656,9 +656,12 @@ def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
// Ops related to OIList primitive
def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> {
let arguments = (ins UnitAttr:$keyword, UnitAttr:$otherKeyword,
UnitAttr:$diffNameUnitAttrKeyword);
let assemblyFormat = [{
oilist( `keyword`
| `otherKeyword`) attr-dict
oilist( `keyword` $keyword
| `otherKeyword` $otherKeyword
| `thirdKeyword` $diffNameUnitAttrKeyword) attr-dict
}];
}
@ -690,8 +693,8 @@ def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> {
UnitAttr:$nowait);
let assemblyFormat = [{
oilist( `private` `(` $arg0 `:` type($arg0) `)`
| `nowait`
| `reduction` custom<CustomOptionalOperand>($optOperand)
| `nowait` $nowait
) attr-dict
}];
}

View File

@ -207,6 +207,18 @@ public:
return llvm::zip(getLiteralElements(), getParsingElements());
}
/// If the parsing element is a single UnitAttr element, then it returns the
/// attribute variable. Otherwise, returns nullptr.
AttributeVariable *
getUnitAttrParsingElement(ArrayRef<FormatElement *> pelement) {
if (pelement.size() == 1) {
auto attrElem = dyn_cast<AttributeVariable>(pelement[0]);
if (attrElem && attrElem->isUnitAttr())
return attrElem;
}
return nullptr;
}
private:
/// A vector of `LiteralElement` objects. Each element stores the keyword
/// for one case of oilist element. For example, an oilist element along with
@ -684,7 +696,6 @@ const char *oilistParserCode = R"(
"oilist directive";
}
{0}Clause = true;
result.addAttribute("{0}", UnitAttr::get(parser.getContext()));
)";
namespace {
@ -778,9 +789,11 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
genElementParserStorage(childElement, op, body);
} else if (auto *oilist = dyn_cast<OIListElement>(element)) {
for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements())
for (FormatElement *element : pelement)
genElementParserStorage(element, op, body);
for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) {
if (!oilist->getUnitAttrParsingElement(pelement))
for (FormatElement *element : pelement)
genElementParserStorage(element, op, body);
}
} else if (auto *custom = dyn_cast<CustomDirective>(element)) {
for (FormatElement *paramElement : custom->getArguments())
@ -1180,11 +1193,16 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
body << "if (succeeded(parser.parseOptional";
genLiteralParser(lelement->getSpelling(), body);
body << ")) {\n";
StringRef attrName = lelement->getSpelling();
body << formatv(oilistParserCode, attrName);
inferredAttributes.insert(attrName);
for (FormatElement *el : pelement)
genElementParser(el, body, attrTypeCtx);
StringRef lelementName = lelement->getSpelling();
body << formatv(oilistParserCode, lelementName);
if (AttributeVariable *unitAttrElem =
oilist->getUnitAttrParsingElement(pelement)) {
body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
<< "\", UnitAttr::get(parser.getContext()));\n";
} else {
for (FormatElement *el : pelement)
genElementParser(el, body, attrTypeCtx);
}
body << " } else ";
}
body << " {\n";
@ -1873,6 +1891,31 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
});
}
void collect(FormatElement *element,
SmallVectorImpl<VariableElement *> &variables) {
TypeSwitch<FormatElement *>(element)
.Case([&](VariableElement *var) { variables.emplace_back(var); })
.Case([&](CustomDirective *ele) {
for (FormatElement *arg : ele->getArguments())
collect(arg, variables);
})
.Case([&](OptionalElement *ele) {
for (FormatElement *arg : ele->getThenElements())
collect(arg, variables);
for (FormatElement *arg : ele->getElseElements())
collect(arg, variables);
})
.Case([&](FunctionalTypeDirective *funcType) {
collect(funcType->getInputs(), variables);
collect(funcType->getResults(), variables);
})
.Case([&](OIListElement *oilist) {
for (ArrayRef<FormatElement *> arg : oilist->getParsingElements())
for (FormatElement *arg_ : arg)
collect(arg_, variables);
});
}
void OperationFormat::genElementPrinter(FormatElement *element,
MethodBody &body, Operator &op,
bool &shouldEmitSpace,
@ -1939,13 +1982,44 @@ void OperationFormat::genElementPrinter(FormatElement *element,
LiteralElement *lelement = std::get<0>(clause);
ArrayRef<FormatElement *> pelement = std::get<1>(clause);
body << " if ((*this)->hasAttrOfType<UnitAttr>(\""
<< lelement->getSpelling() << "\")) {\n";
SmallVector<VariableElement *> vars;
for (FormatElement *el : pelement)
collect(el, vars);
body << " if (false";
for (VariableElement *var : vars) {
TypeSwitch<FormatElement *>(var)
.Case([&](AttributeVariable *attrEle) {
body << " || " << op.getGetterName(attrEle->getVar()->name)
<< "Attr()";
})
.Case([&](OperandVariable *ele) {
if (ele->getVar()->isVariadic()) {
body << " || " << op.getGetterName(ele->getVar()->name)
<< "().size()";
} else {
body << " || " << op.getGetterName(ele->getVar()->name) << "()";
}
})
.Case([&](ResultVariable *ele) {
if (ele->getVar()->isVariadic()) {
body << " || " << op.getGetterName(ele->getVar()->name)
<< "().size()";
} else {
body << " || " << op.getGetterName(ele->getVar()->name) << "()";
}
})
.Case([&](RegionVariable *reg) {
body << " || " << op.getGetterName(reg->getVar()->name) << "()";
});
}
body << ") {\n";
genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace,
lastWasPunctuation);
for (FormatElement *element : pelement) {
genElementPrinter(element, body, op, shouldEmitSpace,
lastWasPunctuation);
if (oilist->getUnitAttrParsingElement(pelement) == nullptr) {
for (FormatElement *element : pelement)
genElementPrinter(element, body, op, shouldEmitSpace,
lastWasPunctuation);
}
body << " }\n";
}
@ -2866,51 +2940,45 @@ OpFormatParser::parseOIListDirective(SMLoc loc, Context context) {
LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element,
SMLoc loc) {
return TypeSwitch<FormatElement *, LogicalResult>(element)
// Only optional attributes can be within an oilist parsing group.
.Case([&](AttributeVariable *attrEle) {
if (!attrEle->getVar()->attr.isOptional())
return emitError(loc, "only optional attributes can be used to "
"in an oilist parsing group");
return success();
})
// Only optional-like(i.e. variadic) operands can be within an oilist
// parsing group.
.Case([&](OperandVariable *ele) {
if (!ele->getVar()->isVariableLength())
return emitError(loc, "only variable length operands can be "
"used within an oilist parsing group");
return success();
})
// Only optional-like(i.e. variadic) results can be within an oilist
// parsing group.
.Case([&](ResultVariable *ele) {
if (!ele->getVar()->isVariableLength())
return emitError(loc, "only variable length results can be "
"used within an oilist parsing group");
return success();
})
.Case([&](RegionVariable *) {
// TODO: When ODS has proper support for marking "optional" regions, add
// a check here.
return success();
})
.Case([&](TypeDirective *ele) {
return verifyOIListParsingElement(ele->getArg(), loc);
})
.Case([&](FunctionalTypeDirective *ele) {
if (failed(verifyOIListParsingElement(ele->getInputs(), loc)))
return failure();
return verifyOIListParsingElement(ele->getResults(), loc);
})
// Literals, whitespace, and custom directives may be used.
.Case<LiteralElement, WhitespaceElement, CustomDirective,
FunctionalTypeDirective, OptionalElement>(
[&](FormatElement *) { return success(); })
.Default([&](FormatElement *) {
return emitError(loc, "only literals, types, and variables can be "
"used within an oilist group");
});
SmallVector<VariableElement *> vars;
collect(element, vars);
for (VariableElement *elem : vars) {
LogicalResult res =
TypeSwitch<FormatElement *, LogicalResult>(elem)
// Only optional attributes can be within an oilist parsing group.
.Case([&](AttributeVariable *attrEle) {
if (!attrEle->getVar()->attr.isOptional() &&
!attrEle->getVar()->attr.hasDefaultValue())
return emitError(loc, "only optional attributes can be used in "
"an oilist parsing group");
return success();
})
// Only optional-like(i.e. variadic) operands can be within an
// oilist parsing group.
.Case([&](OperandVariable *ele) {
if (!ele->getVar()->isVariableLength())
return emitError(loc, "only variable length operands can be "
"used within an oilist parsing group");
return success();
})
// Only optional-like(i.e. variadic) results can be within an oilist
// parsing group.
.Case([&](ResultVariable *ele) {
if (!ele->getVar()->isVariableLength())
return emitError(loc, "only variable length results can be "
"used within an oilist parsing group");
return success();
})
.Case([&](RegionVariable *) { return success(); })
.Default([&](FormatElement *) {
return emitError(loc,
"only literals, types, and variables can be "
"used within an oilist group");
});
if (failed(res))
return failure();
}
return success();
}
FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,