From 350d68644445f53551df1a4ddd69bd4f54f09fff Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Thu, 27 Oct 2022 09:39:52 +0200 Subject: [PATCH] [mlir] Print bbArgs of linalg.map/reduce/transpose on the next line. ``` %mapped = linalg.map ins(%arg0 : tensor<64xf32>) outs(%arg1 : tensor<64xf32>) (%in: f32) { %0 = math.absf %in : f32 linalg.yield %0 : f32 } %reduced = linalg.reduce ins(%arg0 : tensor<16x32x64xf32>) outs(%arg1 : tensor<16x64xf32>) dimensions = [1] (%in: f32, %init: f32) { %0 = arith.addf %in, %init : f32 linalg.yield %0 : f32 } %transposed = linalg.transpose ins(%arg0 : tensor<16x32x64xf32>) outs(%arg1 : tensor<32x64x16xf32>) permutation = [1, 2, 0] ``` Differential Revision: https://reviews.llvm.org/D136818 --- mlir/include/mlir/IR/OpImplementation.h | 6 +++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 37 +++++++++++++++---- mlir/lib/IR/AsmPrinter.cpp | 8 ++++ .../Dialect/Linalg/one-shot-bufferize.mlir | 6 +-- mlir/test/Dialect/Linalg/roundtrip.mlir | 21 ++++++++++- 5 files changed, 66 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 474e9955bdcc..524c72d239cf 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -334,6 +334,12 @@ public: /// operation. virtual void printNewline() = 0; + /// Increase indentation. + virtual void increaseIndent() = 0; + + /// Decrease indentation. + virtual void decreaseIndent() = 0; + /// Print a block argument in the usual format of: /// %ssaName : type {attr1=42} loc("here") /// where location printing is controlled by the standard internal option. 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 5d6dd379b2e4..a7e1938e2418 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -173,6 +173,16 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, p << " outs(" << outputs << " : " << outputs.getTypes() << ")"; } +static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p, + ValueRange inputs, + ValueRange outputs) { + p.printNewline(); + if (!inputs.empty()) + p << "ins(" << inputs << " : " << inputs.getTypes() << ")"; + p.printNewline(); + if (!outputs.empty()) + p << "outs(" << outputs << " : " << outputs.getTypes() << ")"; +} //===----------------------------------------------------------------------===// // Specific parsing and printing for named structured ops created by ods-gen. //===----------------------------------------------------------------------===// @@ -1335,16 +1345,20 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { } void MapOp::print(OpAsmPrinter &p) { - printCommonStructuredOpParts(p, SmallVector(getInputOperands()), - SmallVector(getOutputOperands())); + p.increaseIndent(); + printCommonStructuredOpPartsWithNewLine( + p, SmallVector(getInputOperands()), + SmallVector(getOutputOperands())); p.printOptionalAttrDict((*this)->getAttrs()); + p.printNewline(); p << "("; llvm::interleaveComma(getMapper().getArguments(), p, [&](auto arg) { p.printRegionArgument(arg); }); p << ") "; p.printRegion(getMapper(), /*printEntryBlockArgs=*/false); + p.decreaseIndent(); } LogicalResult MapOp::verify() { @@ -1481,21 +1495,26 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef attributeValue) { - p << " " << attributeName << " = [" << attributeValue << "] "; + p << attributeName << " = [" << attributeValue << "] "; } void 
ReduceOp::print(OpAsmPrinter &p) { - printCommonStructuredOpParts(p, SmallVector(getInputOperands()), - SmallVector(getOutputOperands())); + p.increaseIndent(); + printCommonStructuredOpPartsWithNewLine( + p, SmallVector(getInputOperands()), + SmallVector(getOutputOperands())); + p.printNewline(); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); + p.printNewline(); p << "("; llvm::interleaveComma(getCombiner().getArguments(), p, [&](auto arg) { p.printRegionArgument(arg); }); p << ") "; p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false); + p.decreaseIndent(); } LogicalResult ReduceOp::verify() { @@ -1657,10 +1676,14 @@ void TransposeOp::getAsmResultNames( } void TransposeOp::print(OpAsmPrinter &p) { - printCommonStructuredOpParts(p, SmallVector(getInputOperands()), - SmallVector(getOutputOperands())); + p.increaseIndent(); + printCommonStructuredOpPartsWithNewLine( + p, SmallVector(getInputOperands()), + SmallVector(getOutputOperands())); + p.printNewline(); printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation()); p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); + p.decreaseIndent(); } LogicalResult TransposeOp::verify() { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 676f8133a877..9a3d3e031dc3 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -716,6 +716,8 @@ private: void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {} void printNewline() override {} + void increaseIndent() override {} + void decreaseIndent() override {} void printOperand(Value) override {} void printOperand(Value, raw_ostream &os) override { // Users expect the output string to have at least the prefixed % to signal @@ -2768,6 +2770,12 @@ public: os.indent(currentIndent); } + /// Increase indentation. 
+ void increaseIndent() override { currentIndent += indentWidth; } + + /// Decrease indentation. + void decreaseIndent() override { currentIndent -= indentWidth; } + /// Print a block argument in the usual format of: /// %ssaName : type {attr1=42} loc("here") /// where location printing is controlled by the standard internal option. diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir index e71f566c307c..58dec2be2373 100644 --- a/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize.mlir @@ -341,7 +341,7 @@ func.func @op_is_reading_but_following_ops_are_not( func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> { // CHECK: linalg.map - // CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<64xf32 + // CHECK-NEXT: ins(%[[LHS]], %[[RHS]] : memref<64xf32 %add = linalg.map ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>) outs(%init:tensor<64xf32>) @@ -359,7 +359,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, func.func @reduce(%input: tensor<16x32x64xf32>, %init: tensor<16x64xf32>) -> tensor<16x64xf32> { // CHECK: linalg.reduce - // CHECK-SAME: ins(%[[INPUT]] : memref<16x32x64xf32 + // CHECK-NEXT: ins(%[[INPUT]] : memref<16x32x64xf32 %reduce = linalg.reduce ins(%input:tensor<16x32x64xf32>) outs(%init:tensor<16x64xf32>) @@ -378,7 +378,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>, func.func @transpose(%input: tensor<16x32x64xf32>, %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { // CHECK: linalg.transpose - // CHECK-SAME: ins(%[[ARG0]] : memref<16x32x64xf32 + // CHECK-NEXT: ins(%[[ARG0]] : memref<16x32x64xf32 %transpose = linalg.transpose ins(%input:tensor<16x32x64xf32>) outs(%init:tensor<32x64x16xf32>) diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 4bea3f6d3837..6e1c26634c41 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ 
b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -338,7 +338,13 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, func.return %add : tensor<64xf32> } // CHECK-LABEL: func @map_binary -// CHECK: linalg.map +// CHECK: linalg.map +// CHECK-NEXT: ins +// CHECK-NEXT: outs +// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) { +// CHECK-NEXT: arith.addf +// CHECK-NEXT: linalg.yield +// CHECK-NEXT: } // ----- @@ -401,7 +407,14 @@ func.func @reduce(%input: tensor<16x32x64xf32>, func.return %reduce : tensor<16x64xf32> } // CHECK-LABEL: func @reduce -// CHECK: linalg.reduce +// CHECK: linalg.reduce +// CHECK-NEXT: ins +// CHECK-NEXT: outs +// CHECK-NEXT: dimensions = [1] +// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) { +// CHECK-NEXT: arith.addf +// CHECK-NEXT: linalg.yield +// CHECK-NEXT: } // ----- @@ -469,6 +482,10 @@ func.func @transpose(%input: tensor<16x32x64xf32>, func.return %transpose : tensor<32x64x16xf32> } // CHECK-LABEL: func @transpose +// CHECK: linalg.transpose +// CHECK-NEXT: ins +// CHECK-NEXT: outs +// CHECK-NEXT: permutation // -----