[mlir-tblgen] Support `either` in Tablegen DRR.

Add a new directive `either` to specify the operands can be matched in either order

Reviewed By: jpienaar, Mogball

Differential Revision: https://reviews.llvm.org/D110666
This commit is contained in:
Chia-hung Duan 2021-11-08 22:56:40 +00:00
parent 1b409df613
commit 2d99c815d7
7 changed files with 244 additions and 50 deletions

View File

@ -774,6 +774,23 @@ Explicitly-specified return types will take precedence over return types
inferred from op traits or user-defined builders. The return types of values
replacing root op results cannot be overridden.
### `either`
The `either` directive is used to specify the operands may be matched in either
order.
```tablegen
def : Pat<(TwoArgOp (either $firstArg, (AnOp $secondArg))),
(...)>;
```
The above pattern will accept either `"test.TwoArgOp"(%I32Arg, %AnOpArg)` and
`"test.TwoArgOp"(%AnOpArg, %I32Arg)`.
Only operand is supported with `either` and note that an operation with
`Commutative` trait doesn't imply that it'll have the same behavior than
`either` while pattern matching.
## Debugging Tips
### Run `mlir-tblgen` to see the generated content

View File

@ -2730,6 +2730,21 @@ def location;
def returnType;
// Directive used to specify the operands may be matched in either order. When
// two adjacents are marked with `either`, it'll try to match the operands in
// either ordering of constraints. Example:
//
// ```
// (TwoArgOp (either $firstArg, (AnOp $secondArg)))
// ```
// The above pattern will accept either `"test.TwoArgOp"(%I32Arg, %AnOpArg)` and
// `"test.TwoArgOp"(%AnOpArg, %I32Arg)`.
//
// Only operand is supported with `either` and note that an operation with
// `Commutative` trait doesn't imply that it'll have the same behavior than
// `either` while pattern matching.
def either;
//===----------------------------------------------------------------------===//
// Attribute and Type generation
//===----------------------------------------------------------------------===//

View File

@ -186,6 +186,9 @@ public:
// Returns true if this DAG node is wrapping native code call.
bool isNativeCodeCall() const;
// Returns whether this DAG is an `either` specifier.
bool isEither() const;
// Returns true if this DAG node is an operation.
bool isOperation() const;

View File

@ -113,7 +113,7 @@ bool DagNode::isNativeCodeCall() const {
bool DagNode::isOperation() const {
return !isNativeCodeCall() && !isReplaceWithValue() &&
!isLocationDirective() && !isReturnTypeDirective();
!isLocationDirective() && !isReturnTypeDirective() && !isEither();
}
llvm::StringRef DagNode::getNativeCodeTemplate() const {
@ -142,7 +142,9 @@ Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const {
}
int DagNode::getNumOps() const {
int count = isReplaceWithValue() ? 0 : 1;
// We want to get number of operations recursively involved in the DAG tree.
// All other directives should be excluded.
int count = isOperation() ? 1 : 0;
for (int i = 0, e = getNumArgs(); i != e; ++i) {
if (auto child = getArgAsNestedDag(i))
count += child.getNumOps();
@ -184,6 +186,11 @@ bool DagNode::isReturnTypeDirective() const {
return dagOpDef->getName() == "returnType";
}
bool DagNode::isEither() const {
auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef();
return dagOpDef->getName() == "either";
}
void DagNode::print(raw_ostream &os) const {
if (node)
node->print(os);
@ -764,22 +771,25 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
if (tree.isOperation()) {
auto &op = getDialectOp(tree);
auto numOpArgs = op.getNumArgs();
int numEither = 0;
// The pattern might have trailing directives.
// We need to exclude the trailing directives and `either` directive groups
// two operands of the operation.
int numDirectives = 0;
for (int i = numTreeArgs - 1; i >= 0; --i) {
if (auto dagArg = tree.getArgAsNestedDag(i)) {
if (dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())
++numDirectives;
else
break;
else if (dagArg.isEither())
++numEither;
}
}
if (numOpArgs != numTreeArgs - numDirectives) {
auto err = formatv("op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
op.getOperationName(), numTreeArgs, numOpArgs);
if (numOpArgs != numTreeArgs - numDirectives + numEither) {
auto err =
formatv("op '{0}' argument number mismatch: "
"{1} in pattern vs. {2} in definition",
op.getOperationName(), numTreeArgs + numEither, numOpArgs);
PrintFatalError(&def, err);
}
@ -791,10 +801,30 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
verifyBind(infoMap.bindOpResult(treeName, op), treeName);
}
for (int i = 0; i != numTreeArgs; ++i) {
// The operand in `either` DAG should be bound to the operation in the
// parent DagNode.
auto collectSymbolInEither = [&](DagNode parent, DagNode tree,
int &opArgIdx) {
for (int i = 0; i < tree.getNumArgs(); ++i, ++opArgIdx) {
if (DagNode subTree = tree.getArgAsNestedDag(i)) {
collectBoundSymbols(subTree, infoMap, isSrcPattern);
} else {
auto argName = tree.getArgName(i);
if (!argName.empty() && argName != "_")
verifyBind(infoMap.bindOpArgument(parent, argName, op, opArgIdx),
argName);
}
}
};
for (int i = 0, opArgIdx = 0; i != numTreeArgs; ++i, ++opArgIdx) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
if (treeArg.isEither()) {
collectSymbolInEither(tree, treeArg, opArgIdx);
} else {
// This DAG node argument is a DAG node itself. Go inside recursively.
collectBoundSymbols(treeArg, infoMap, isSrcPattern);
}
continue;
}
@ -806,7 +836,7 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
if (!treeArgName.empty() && treeArgName != "_") {
LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: "
<< treeArgName << '\n');
verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, i),
verifyBind(infoMap.bindOpArgument(tree, treeArgName, op, opArgIdx),
treeArgName);
}
}

View File

@ -1328,6 +1328,30 @@ def : Pat<(OneI32ResultOp),
(replaceWithValue $results__2),
ConstantAttr<I32Attr, "2">)>;
//===----------------------------------------------------------------------===//
// Test Patterns (either)
def TestEitherOpA : TEST_Op<"either_op_a"> {
let arguments = (ins AnyInteger:$arg0, AnyInteger:$arg1, AnyInteger:$arg2);
let results = (outs I32:$output);
}
def TestEitherOpB : TEST_Op<"either_op_b"> {
let arguments = (ins AnyInteger:$arg0);
let results = (outs I32:$output);
}
def : Pat<(TestEitherOpA (either I32:$arg1, I16:$arg2), $_),
(TestEitherOpB $arg2)>;
def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1), I16:$arg2), $_),
(TestEitherOpB $arg2)>;
def : Pat<(TestEitherOpA (either (TestEitherOpB I32:$arg1),
(TestEitherOpB I16:$arg2)),
$_),
(TestEitherOpB $arg2)>;
//===----------------------------------------------------------------------===//
// Test Patterns (Location)

View File

@ -531,6 +531,40 @@ func @redundantTest(%arg0: i32) -> i32 {
return %0 : i32
}
//===----------------------------------------------------------------------===//
// Test either directive
//===----------------------------------------------------------------------===//
// CHECK: @either_dag_leaf_only
func @either_dag_leaf_only_1(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
%0 = "test.either_op_a"(%arg0, %arg1, %arg2) : (i32, i16, i8) -> i32
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
%1 = "test.either_op_a"(%arg1, %arg0, %arg2) : (i16, i32, i8) -> i32
return
}
// CHECK: @either_dag_leaf_dag_node
func @either_dag_leaf_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
%0 = "test.either_op_b"(%arg0) : (i32) -> i32
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
%1 = "test.either_op_a"(%0, %arg1, %arg2) : (i32, i16, i8) -> i32
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
%2 = "test.either_op_a"(%arg1, %0, %arg2) : (i16, i32, i8) -> i32
return
}
// CHECK: @either_dag_node_dag_node
func @either_dag_node_dag_node(%arg0 : i32, %arg1 : i16, %arg2 : i8) -> () {
%0 = "test.either_op_b"(%arg0) : (i32) -> i32
%1 = "test.either_op_b"(%arg1) : (i16) -> i32
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
%2 = "test.either_op_a"(%0, %1, %arg2) : (i32, i32, i8) -> i32
// CHECK: "test.either_op_b"(%arg1) : (i16) -> i32
%3 = "test.either_op_a"(%1, %0, %arg2) : (i32, i32, i8) -> i32
return
}
//===----------------------------------------------------------------------===//
// Test that ops without type deduction can be created with type builders.
//===----------------------------------------------------------------------===//

View File

@ -117,10 +117,17 @@ private:
void emitOpMatch(DagNode tree, StringRef opName, int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an operand. operandIndex is the index in the DAG excluding
// the preceding attributes.
void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
int operandIndex, int depth);
// DAG `tree` as an operand. `operandName` and `operandMatcher` indicate the
// bound name and the constraint of the operand respectively.
void emitOperandMatch(DagNode tree, StringRef opName, StringRef operandName,
DagLeaf operandMatcher, StringRef argName,
int argIndex);
// Emits C++ statements for matching the operands which can be matched in
// either order.
void emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
StringRef opName, int argIndex, int &operandIndex,
int depth);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
@ -470,6 +477,9 @@ void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
std::string argName = formatv("arg{0}_{1}", depth, i);
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
if (argTree.isEither())
PrintFatalError(loc, "NativeCodeCall cannot have `either` operands");
os << "Value " << argName << ";\n";
} else {
auto leaf = tree.getArgAsLeaf(i);
@ -584,12 +594,6 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
formatv("\"{0} is not {1} type\"", castedName,
op.getQualCppClassName()));
if (tree.getNumArgs() != op.getNumArgs())
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
"pattern vs. {2} in definition",
op.getOperationName(), tree.getNumArgs(),
op.getNumArgs()));
// If the operand's name is set, set to that variable.
auto name = tree.getSymbol();
if (!name.empty())
@ -601,6 +605,11 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
// Handle nested DAG construct first
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
if (argTree.isEither()) {
emitEitherOperandMatch(tree, argTree, castedName, i, nextOperand,
depth);
continue;
}
if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
if (operand->isVariableLength()) {
auto error = formatv("use nested DAG construct to match op {0}'s "
@ -609,6 +618,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
PrintFatalError(loc, error);
}
}
os << "{\n";
// Attributes don't count for getODSOperands.
@ -618,9 +628,10 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
"(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
argName, castedName, nextOperand);
// Null check of operand's definingOp
emitMatchCheck(castedName, /*matchStr=*/argName,
formatv("\"Operand {0} of {1} has null definingOp\"",
nextOperand++, castedName));
emitMatchCheck(
castedName, /*matchStr=*/argName,
formatv("\"There's no operation that defines operand {0} of {1}\"",
nextOperand++, castedName));
emitMatch(argTree, argName, depth + 1);
os << formatv("tblgen_ops.push_back({0});\n", argName);
os.unindent() << "}\n";
@ -629,8 +640,12 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
// Next handle DAG leaf: operand or attribute
if (opArg.is<NamedTypeConstraint *>()) {
// emitOperandMatch's argument indexing counts attributes.
emitOperandMatch(tree, castedName, i, nextOperand, depth);
auto operandName =
formatv("{0}.getODSOperands({1})", castedName, nextOperand);
emitOperandMatch(tree, castedName, operandName.str(),
/*operandMatcher=*/tree.getArgAsLeaf(i),
/*argName=*/tree.getArgName(i),
/*argIndex=*/i);
++nextOperand;
} else if (opArg.is<NamedAttribute *>()) {
emitAttributeMatch(tree, opName, i, depth);
@ -644,24 +659,23 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
}
void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
int argIndex, int operandIndex,
int depth) {
StringRef operandName,
DagLeaf operandMatcher, StringRef argName,
int argIndex) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
auto matcher = tree.getArgAsLeaf(argIndex);
// If a constraint is specified, we need to generate C++ statements to
// check the constraint.
if (!matcher.isUnspecified()) {
if (!matcher.isOperandMatcher()) {
if (!operandMatcher.isUnspecified()) {
if (!operandMatcher.isOperandMatcher())
PrintFatalError(
loc, formatv("the {1}-th argument of op '{0}' should be an operand",
op.getOperationName(), argIndex + 1));
}
// Only need to verify if the matcher's type is different from the one
// of op definition.
Constraint constraint = matcher.getAsConstraint();
Constraint constraint = operandMatcher.getAsConstraint();
if (operand->constraint != constraint) {
if (operand->isVariableLength()) {
auto error = formatv(
@ -669,36 +683,93 @@ void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
op.getOperationName(), argIndex);
PrintFatalError(loc, error);
}
auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
opName, operandIndex);
auto self = formatv("(*{0}.begin()).getType()", operandName);
StringRef verifier = staticMatcherHelper.getVerifierName(constraint);
emitStaticVerifierCall(
verifier, opName, self.str(),
formatv(
"\"operand {0} of op '{1}' failed to satisfy constraint: '{2}'\"",
operandIndex, op.getOperationName(),
operand - op.operand_begin(), op.getOperationName(),
escapeString(constraint.getSummary()))
.str());
}
}
// Capture the value
auto name = tree.getArgName(argIndex);
// `$_` is a special symbol to ignore op argument matching.
if (!name.empty() && name != "_") {
// We need to subtract the number of attributes before this operand to get
// the index in the operand list.
auto numPrevAttrs = std::count_if(
op.arg_begin(), op.arg_begin() + argIndex,
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex);
os << formatv("{0} = {1}.getODSOperands({2});\n",
res->second.getVarName(name), opName,
argIndex - numPrevAttrs);
if (!argName.empty() && argName != "_") {
auto res = symbolInfoMap.findBoundSymbol(argName, tree, op, argIndex);
os << formatv("{0} = {1};\n", res->second.getVarName(argName), operandName);
}
}
void PatternEmitter::emitEitherOperandMatch(DagNode tree, DagNode eitherArgTree,
StringRef opName, int argIndex,
int &operandIndex, int depth) {
constexpr int numEitherArgs = 2;
if (eitherArgTree.getNumArgs() != numEitherArgs)
PrintFatalError(loc, "`either` only supports grouping two operands");
Operator &op = tree.getDialectOp(opMap);
std::string codeBuffer;
llvm::raw_string_ostream tblgenOps(codeBuffer);
std::string lambda = formatv("eitherLambda{0}", depth);
os << formatv("auto {0} = [&](OperandRange v0, OperandRange v1) {{\n",
lambda);
os.indent();
for (int i = 0; i < numEitherArgs; ++i, ++argIndex) {
if (DagNode argTree = eitherArgTree.getArgAsNestedDag(i)) {
if (argTree.isEither())
PrintFatalError(loc, "either cannot be nested");
std::string argName = formatv("local_op_{0}", i).str();
os << formatv("auto {0} = (*v{1}.begin()).getDefiningOp();\n", argName,
i);
emitMatchCheck(
opName, /*matchStr=*/argName,
formatv("\"There's no operation that defines operand {0} of {1}\"",
operandIndex++, opName));
emitMatch(argTree, argName, depth + 1);
// `tblgen_ops` is used to collect the matched operations. In either, we
// need to queue the operation only if the matching success. Thus we emit
// the code at the end.
tblgenOps << formatv("tblgen_ops.push_back({0});\n", argName);
} else if (op.getArg(argIndex).is<NamedTypeConstraint *>()) {
emitOperandMatch(tree, opName, /*operandName=*/formatv("v{0}", i).str(),
/*operandMatcher=*/eitherArgTree.getArgAsLeaf(i),
/*argName=*/eitherArgTree.getArgName(i), argIndex);
++operandIndex;
} else {
PrintFatalError(loc, "either can only be applied on operand");
}
}
os << tblgenOps.str();
os << "return success();\n";
os.unindent() << "};\n";
os << "{\n";
os.indent();
os << formatv("auto eitherOperand0 = {0}.getODSOperands({1});\n", opName,
operandIndex - 2);
os << formatv("auto eitherOperand1 = {0}.getODSOperands({1});\n", opName,
operandIndex - 1);
os << formatv("if(failed({0}(eitherOperand0, eitherOperand1)) && "
"failed({0}(eitherOperand1, "
"eitherOperand0)))\n",
lambda);
os.indent() << "return failure();\n";
os.unindent().unindent() << "}\n";
}
void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
int argIndex, int depth) {
Operator &op = tree.getDialectOp(opMap);