From 31a003dc3c17eba12f7dcb4dda626bc7ca41c0c5 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Fri, 23 Aug 2019 10:35:24 -0700 Subject: [PATCH] Introduce the ability for "isolated from above" ops to introduce shadowing names for the basic block arguments in their body. PiperOrigin-RevId: 265084627 --- mlir/include/mlir/IR/OpImplementation.h | 7 +++ mlir/lib/IR/AsmPrinter.cpp | 61 ++++++++++++++++++++--- mlir/test/IR/parser.mlir | 20 ++++++-- mlir/test/lib/TestDialect/TestDialect.cpp | 7 +++ mlir/test/lib/TestDialect/TestOps.td | 1 + 5 files changed, 85 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 99c1ff553918..c4e87ce3eef2 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -85,6 +85,13 @@ public: virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, bool printBlockTerminators = true) = 0; + /// Renumber the arguments for the specified region to the same names as the + /// SSA values in namesToUse. This may only be used for IsolatedFromAbove + /// operations. If any entry in namesToUse is null, the corresponding + /// argument name is left alone. + virtual void shadowRegionArgs(Region ®ion, + ArrayRef namesToUse) = 0; + /// Prints an affine map of SSA ids, where SSA id names are used in place /// of dims/symbols. /// Operand values must come from single-result sources, and be valid diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 82f2c9970a6b..9da922cd6212 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1244,6 +1244,12 @@ public: os.indent(currentIndent) << "}"; } + /// Renumber the arguments for the specified region to the same names as the + /// SSA values in namesToUse. This may only be used for IsolatedFromAbove + /// operations. If any entry in namesToUse is null, the corresponding + /// argument name is left alone. + void shadowRegionArgs(Region ®ion, ArrayRef namesToUse) override; + void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ArrayRef operands) override { AffineMap map = mapAttr.getValue(); @@ -1270,9 +1276,14 @@ protected: void numberValueID(Value *value); void numberValuesInRegion(Region ®ion); void numberValuesInBlock(Block &block); - void printValueID(Value *value, bool printResultNo = true) const; + void printValueID(Value *value, bool printResultNo = true) const { + printValueIDImpl(value, printResultNo, os); + } private: + void printValueIDImpl(Value *value, bool printResultNo, + raw_ostream &stream) const; + /// Uniques the given value name within the printer. If the given name /// conflicts, it is automatically renamed. StringRef uniqueValueName(StringRef name); @@ -1491,7 +1502,8 @@ void OperationPrinter::print(Operation *op) { printTrailingLocation(op->getLoc()); } -void OperationPrinter::printValueID(Value *value, bool printResultNo) const { +void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, + raw_ostream &stream) const { int resultNo = -1; auto lookupValue = value; @@ -1507,21 +1519,56 @@ void OperationPrinter::printValueID(Value *value, bool printResultNo) const { auto it = valueIDs.find(lookupValue); if (it == valueIDs.end()) { - os << "<>"; + stream << "<>"; return; } - os << '%'; + stream << '%'; if (it->second != nameSentinel) { - os << it->second; + stream << it->second; } else { auto nameIt = valueNames.find(lookupValue); assert(nameIt != valueNames.end() && "Didn't have a name entry?"); - os << nameIt->second; + stream << nameIt->second; } if (resultNo != -1 && printResultNo) - os << '#' << resultNo; + stream << '#' << resultNo; +} + +/// Renumber the arguments for the specified region to the same names as the +/// SSA values in namesToUse. This may only be used for IsolatedFromAbove +/// operations. If any entry in namesToUse is null, the corresponding +/// argument name is left alone. +void OperationPrinter::shadowRegionArgs(Region ®ion, + ArrayRef namesToUse) { + assert(!region.empty() && "cannot shadow arguments of an empty region"); + assert(region.front().getNumArguments() == namesToUse.size() && + "incorrect number of names passed in"); + assert(region.getParentOp()->isKnownIsolatedFromAbove() && + "only KnownIsolatedFromAbove ops can shadow names"); + + SmallVector nameStr; + for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) { + auto *nameToUse = namesToUse[i]; + if (nameToUse == nullptr) + continue; + + auto *nameToReplace = region.front().getArgument(i); + + nameStr.clear(); + llvm::raw_svector_ostream nameStream(nameStr); + printValueIDImpl(nameToUse, /*printResultNo=*/true, nameStream); + + // Entry block arguments should already have a pretty "arg" name. + assert(valueIDs[nameToReplace] == nameSentinel); + + // Use the name without the leading %. + auto name = StringRef(nameStream.str()).drop_front(); + + // Overwrite the name. + valueNames[nameToReplace] = name.copy(usedNameAllocator); + } } void OperationPrinter::printOperation(Operation *op) { diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 6f576a837989..db4a096c6546 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1055,13 +1055,25 @@ func @op_with_region_args() { // CHECK-LABEL: func @op_with_passthrough_region_args func @op_with_passthrough_region_args() { // CHECK: [[VAL:%.*]] = constant - // CHECK: "test.isolated_region"([[VAL]]) - // CHECK-NEXT: ^{{.*}}([[ARG:%.*]]: index) - // CHECK-NEXT: "foo.consumer"([[ARG]]) : (index) - %0 = constant 10 : index + + // CHECK: test.isolated_region [[VAL]] { + // CHECK-NEXT: "foo.consumer"([[VAL]]) : (index) + // CHECK-NEXT: } test.isolated_region %0 { "foo.consumer"(%0) : (index) -> () } + + // CHECK: [[VAL:%.*]]:2 = "foo.op" + %result:2 = "foo.op"() : () -> (index, index) + + // CHECK: test.isolated_region [[VAL]]#1 { + // CHECK-NEXT: "foo.consumer"([[VAL]]#1) : (index) + // CHECK-NEXT: } + test.isolated_region %result#1 { + "foo.consumer"(%result#1) : (index) -> () + } + return } + diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 40faa0dccdf8..8b44b6cb548a 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -54,6 +54,13 @@ static ParseResult parseIsolatedRegionOp(OpAsmParser *parser, /*enableNameShadowing=*/true); } +static void print(OpAsmPrinter *p, IsolatedRegionOp op) { + *p << "test.isolated_region "; + p->printOperand(op.getOperand()); + p->shadowRegionArgs(op.region(), op.getOperand()); + p->printRegion(op.region(), /*printEntryBlockArgs=*/false); +} + //===----------------------------------------------------------------------===// // Test PolyForOp - parse list of region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 55466b734f11..29269309fd83 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -704,6 +704,7 @@ def IsolatedRegionOp : TEST_Op<"isolated_region", [IsolatedFromAbove]> { let arguments = (ins Index:$input); let regions = (region SizedRegion<1>:$region); let parser = [{ return ::parse$cppClass(parser, result); }]; + let printer = [{ return ::print(p, *this); }]; } def PolyForOp : TEST_Op<"polyfor">