[mlir] Printing oilist element
This patch attempts to deduce when the oilist element must be printed based on the optional arguments to it. This especially helps creating an operation accurately because with the current implementation, the inferred unit attributes must be manually added to print the clauses appropriately. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D121579
This commit is contained in:
parent
f863df9a05
commit
ddc90da478
|
@ -199,7 +199,7 @@ def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> {
|
|||
$allocate_vars, type($allocate_vars),
|
||||
$allocators_vars, type($allocators_vars)
|
||||
) `)`
|
||||
| `nowait`
|
||||
| `nowait` $nowait
|
||||
) $region attr-dict
|
||||
}];
|
||||
|
||||
|
@ -438,7 +438,7 @@ def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> {
|
|||
oilist( `if` `(` $if_expr `)`
|
||||
| `device` `(` $device `:` type($device) `)`
|
||||
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
|
||||
| `nowait`
|
||||
| `nowait` $nowait
|
||||
) $region attr-dict
|
||||
}];
|
||||
}
|
||||
|
|
|
@ -498,6 +498,10 @@ func @succeededOilistTrivial() {
|
|||
test.oilist_with_keywords_only keyword otherKeyword
|
||||
// CHECK: test.oilist_with_keywords_only keyword otherKeyword
|
||||
test.oilist_with_keywords_only otherKeyword keyword
|
||||
// CHECK: test.oilist_with_keywords_only thirdKeyword
|
||||
test.oilist_with_keywords_only thirdKeyword
|
||||
// CHECK: test.oilist_with_keywords_only keyword thirdKeyword
|
||||
test.oilist_with_keywords_only keyword thirdKeyword
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -550,7 +554,7 @@ func @succeededOilistCustom(%arg0: i32, %arg1: i32, %arg2: i32) {
|
|||
test.oilist_custom private (%arg0, %arg1 : i32, i32)
|
||||
// CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) nowait
|
||||
test.oilist_custom private (%arg0, %arg1 : i32, i32) nowait
|
||||
// CHECK: test.oilist_custom private(%arg0, %arg1 : i32, i32) nowait reduction (%arg1)
|
||||
// CHECK: test.oilist_custom private(%arg0, %arg1 : i32, i32) reduction (%arg1) nowait
|
||||
test.oilist_custom nowait reduction (%arg1) private (%arg0, %arg1 : i32, i32)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -656,9 +656,12 @@ def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
|
|||
|
||||
// Ops related to OIList primitive
|
||||
def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> {
|
||||
let arguments = (ins UnitAttr:$keyword, UnitAttr:$otherKeyword,
|
||||
UnitAttr:$diffNameUnitAttrKeyword);
|
||||
let assemblyFormat = [{
|
||||
oilist( `keyword`
|
||||
| `otherKeyword`) attr-dict
|
||||
oilist( `keyword` $keyword
|
||||
| `otherKeyword` $otherKeyword
|
||||
| `thirdKeyword` $diffNameUnitAttrKeyword) attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -690,8 +693,8 @@ def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> {
|
|||
UnitAttr:$nowait);
|
||||
let assemblyFormat = [{
|
||||
oilist( `private` `(` $arg0 `:` type($arg0) `)`
|
||||
| `nowait`
|
||||
| `reduction` custom<CustomOptionalOperand>($optOperand)
|
||||
| `nowait` $nowait
|
||||
) attr-dict
|
||||
}];
|
||||
}
|
||||
|
|
|
@ -207,6 +207,18 @@ public:
|
|||
return llvm::zip(getLiteralElements(), getParsingElements());
|
||||
}
|
||||
|
||||
/// If the parsing element is a single UnitAttr element, then it returns the
|
||||
/// attribute variable. Otherwise, returns nullptr.
|
||||
AttributeVariable *
|
||||
getUnitAttrParsingElement(ArrayRef<FormatElement *> pelement) {
|
||||
if (pelement.size() == 1) {
|
||||
auto attrElem = dyn_cast<AttributeVariable>(pelement[0]);
|
||||
if (attrElem && attrElem->isUnitAttr())
|
||||
return attrElem;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
/// A vector of `LiteralElement` objects. Each element stores the keyword
|
||||
/// for one case of oilist element. For example, an oilist element along with
|
||||
|
@ -684,7 +696,6 @@ const char *oilistParserCode = R"(
|
|||
"oilist directive";
|
||||
}
|
||||
{0}Clause = true;
|
||||
result.addAttribute("{0}", UnitAttr::get(parser.getContext()));
|
||||
)";
|
||||
|
||||
namespace {
|
||||
|
@ -778,9 +789,11 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
|
|||
genElementParserStorage(childElement, op, body);
|
||||
|
||||
} else if (auto *oilist = dyn_cast<OIListElement>(element)) {
|
||||
for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements())
|
||||
for (FormatElement *element : pelement)
|
||||
genElementParserStorage(element, op, body);
|
||||
for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) {
|
||||
if (!oilist->getUnitAttrParsingElement(pelement))
|
||||
for (FormatElement *element : pelement)
|
||||
genElementParserStorage(element, op, body);
|
||||
}
|
||||
|
||||
} else if (auto *custom = dyn_cast<CustomDirective>(element)) {
|
||||
for (FormatElement *paramElement : custom->getArguments())
|
||||
|
@ -1180,11 +1193,16 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
|
|||
body << "if (succeeded(parser.parseOptional";
|
||||
genLiteralParser(lelement->getSpelling(), body);
|
||||
body << ")) {\n";
|
||||
StringRef attrName = lelement->getSpelling();
|
||||
body << formatv(oilistParserCode, attrName);
|
||||
inferredAttributes.insert(attrName);
|
||||
for (FormatElement *el : pelement)
|
||||
genElementParser(el, body, attrTypeCtx);
|
||||
StringRef lelementName = lelement->getSpelling();
|
||||
body << formatv(oilistParserCode, lelementName);
|
||||
if (AttributeVariable *unitAttrElem =
|
||||
oilist->getUnitAttrParsingElement(pelement)) {
|
||||
body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
|
||||
<< "\", UnitAttr::get(parser.getContext()));\n";
|
||||
} else {
|
||||
for (FormatElement *el : pelement)
|
||||
genElementParser(el, body, attrTypeCtx);
|
||||
}
|
||||
body << " } else ";
|
||||
}
|
||||
body << " {\n";
|
||||
|
@ -1873,6 +1891,31 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
|
|||
});
|
||||
}
|
||||
|
||||
void collect(FormatElement *element,
|
||||
SmallVectorImpl<VariableElement *> &variables) {
|
||||
TypeSwitch<FormatElement *>(element)
|
||||
.Case([&](VariableElement *var) { variables.emplace_back(var); })
|
||||
.Case([&](CustomDirective *ele) {
|
||||
for (FormatElement *arg : ele->getArguments())
|
||||
collect(arg, variables);
|
||||
})
|
||||
.Case([&](OptionalElement *ele) {
|
||||
for (FormatElement *arg : ele->getThenElements())
|
||||
collect(arg, variables);
|
||||
for (FormatElement *arg : ele->getElseElements())
|
||||
collect(arg, variables);
|
||||
})
|
||||
.Case([&](FunctionalTypeDirective *funcType) {
|
||||
collect(funcType->getInputs(), variables);
|
||||
collect(funcType->getResults(), variables);
|
||||
})
|
||||
.Case([&](OIListElement *oilist) {
|
||||
for (ArrayRef<FormatElement *> arg : oilist->getParsingElements())
|
||||
for (FormatElement *arg_ : arg)
|
||||
collect(arg_, variables);
|
||||
});
|
||||
}
|
||||
|
||||
void OperationFormat::genElementPrinter(FormatElement *element,
|
||||
MethodBody &body, Operator &op,
|
||||
bool &shouldEmitSpace,
|
||||
|
@ -1939,13 +1982,44 @@ void OperationFormat::genElementPrinter(FormatElement *element,
|
|||
LiteralElement *lelement = std::get<0>(clause);
|
||||
ArrayRef<FormatElement *> pelement = std::get<1>(clause);
|
||||
|
||||
body << " if ((*this)->hasAttrOfType<UnitAttr>(\""
|
||||
<< lelement->getSpelling() << "\")) {\n";
|
||||
SmallVector<VariableElement *> vars;
|
||||
for (FormatElement *el : pelement)
|
||||
collect(el, vars);
|
||||
body << " if (false";
|
||||
for (VariableElement *var : vars) {
|
||||
TypeSwitch<FormatElement *>(var)
|
||||
.Case([&](AttributeVariable *attrEle) {
|
||||
body << " || " << op.getGetterName(attrEle->getVar()->name)
|
||||
<< "Attr()";
|
||||
})
|
||||
.Case([&](OperandVariable *ele) {
|
||||
if (ele->getVar()->isVariadic()) {
|
||||
body << " || " << op.getGetterName(ele->getVar()->name)
|
||||
<< "().size()";
|
||||
} else {
|
||||
body << " || " << op.getGetterName(ele->getVar()->name) << "()";
|
||||
}
|
||||
})
|
||||
.Case([&](ResultVariable *ele) {
|
||||
if (ele->getVar()->isVariadic()) {
|
||||
body << " || " << op.getGetterName(ele->getVar()->name)
|
||||
<< "().size()";
|
||||
} else {
|
||||
body << " || " << op.getGetterName(ele->getVar()->name) << "()";
|
||||
}
|
||||
})
|
||||
.Case([&](RegionVariable *reg) {
|
||||
body << " || " << op.getGetterName(reg->getVar()->name) << "()";
|
||||
});
|
||||
}
|
||||
|
||||
body << ") {\n";
|
||||
genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace,
|
||||
lastWasPunctuation);
|
||||
for (FormatElement *element : pelement) {
|
||||
genElementPrinter(element, body, op, shouldEmitSpace,
|
||||
lastWasPunctuation);
|
||||
if (oilist->getUnitAttrParsingElement(pelement) == nullptr) {
|
||||
for (FormatElement *element : pelement)
|
||||
genElementPrinter(element, body, op, shouldEmitSpace,
|
||||
lastWasPunctuation);
|
||||
}
|
||||
body << " }\n";
|
||||
}
|
||||
|
@ -2866,51 +2940,45 @@ OpFormatParser::parseOIListDirective(SMLoc loc, Context context) {
|
|||
|
||||
LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element,
|
||||
SMLoc loc) {
|
||||
return TypeSwitch<FormatElement *, LogicalResult>(element)
|
||||
// Only optional attributes can be within an oilist parsing group.
|
||||
.Case([&](AttributeVariable *attrEle) {
|
||||
if (!attrEle->getVar()->attr.isOptional())
|
||||
return emitError(loc, "only optional attributes can be used to "
|
||||
"in an oilist parsing group");
|
||||
return success();
|
||||
})
|
||||
// Only optional-like(i.e. variadic) operands can be within an oilist
|
||||
// parsing group.
|
||||
.Case([&](OperandVariable *ele) {
|
||||
if (!ele->getVar()->isVariableLength())
|
||||
return emitError(loc, "only variable length operands can be "
|
||||
"used within an oilist parsing group");
|
||||
return success();
|
||||
})
|
||||
// Only optional-like(i.e. variadic) results can be within an oilist
|
||||
// parsing group.
|
||||
.Case([&](ResultVariable *ele) {
|
||||
if (!ele->getVar()->isVariableLength())
|
||||
return emitError(loc, "only variable length results can be "
|
||||
"used within an oilist parsing group");
|
||||
return success();
|
||||
})
|
||||
.Case([&](RegionVariable *) {
|
||||
// TODO: When ODS has proper support for marking "optional" regions, add
|
||||
// a check here.
|
||||
return success();
|
||||
})
|
||||
.Case([&](TypeDirective *ele) {
|
||||
return verifyOIListParsingElement(ele->getArg(), loc);
|
||||
})
|
||||
.Case([&](FunctionalTypeDirective *ele) {
|
||||
if (failed(verifyOIListParsingElement(ele->getInputs(), loc)))
|
||||
return failure();
|
||||
return verifyOIListParsingElement(ele->getResults(), loc);
|
||||
})
|
||||
// Literals, whitespace, and custom directives may be used.
|
||||
.Case<LiteralElement, WhitespaceElement, CustomDirective,
|
||||
FunctionalTypeDirective, OptionalElement>(
|
||||
[&](FormatElement *) { return success(); })
|
||||
.Default([&](FormatElement *) {
|
||||
return emitError(loc, "only literals, types, and variables can be "
|
||||
"used within an oilist group");
|
||||
});
|
||||
SmallVector<VariableElement *> vars;
|
||||
collect(element, vars);
|
||||
for (VariableElement *elem : vars) {
|
||||
LogicalResult res =
|
||||
TypeSwitch<FormatElement *, LogicalResult>(elem)
|
||||
// Only optional attributes can be within an oilist parsing group.
|
||||
.Case([&](AttributeVariable *attrEle) {
|
||||
if (!attrEle->getVar()->attr.isOptional() &&
|
||||
!attrEle->getVar()->attr.hasDefaultValue())
|
||||
return emitError(loc, "only optional attributes can be used in "
|
||||
"an oilist parsing group");
|
||||
return success();
|
||||
})
|
||||
// Only optional-like(i.e. variadic) operands can be within an
|
||||
// oilist parsing group.
|
||||
.Case([&](OperandVariable *ele) {
|
||||
if (!ele->getVar()->isVariableLength())
|
||||
return emitError(loc, "only variable length operands can be "
|
||||
"used within an oilist parsing group");
|
||||
return success();
|
||||
})
|
||||
// Only optional-like(i.e. variadic) results can be within an oilist
|
||||
// parsing group.
|
||||
.Case([&](ResultVariable *ele) {
|
||||
if (!ele->getVar()->isVariableLength())
|
||||
return emitError(loc, "only variable length results can be "
|
||||
"used within an oilist parsing group");
|
||||
return success();
|
||||
})
|
||||
.Case([&](RegionVariable *) { return success(); })
|
||||
.Default([&](FormatElement *) {
|
||||
return emitError(loc,
|
||||
"only literals, types, and variables can be "
|
||||
"used within an oilist group");
|
||||
});
|
||||
if (failed(res))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
|
||||
|
|
Loading…
Reference in New Issue