[mlir] Print bbArgs of linalg.map/reduce/tranpose on the next line.

```
%mapped = linalg.map
  ins(%arg0 : tensor<64xf32>)
  outs(%arg1 : tensor<64xf32>)
  (%in: f32) {
    %0 = math.absf %in : f32
    linalg.yield %0 : f32
  }
%reduced = linalg.reduce
  ins(%arg0 : tensor<16x32x64xf32>)
  outs(%arg1 : tensor<16x64xf32>)
  dimensions = [1]
  (%in: f32, %init: f32) {
    %0 = arith.addf %in, %init : f32
    linalg.yield %0 : f32
  }
%transposed = linalg.transpose
  ins(%arg0 : tensor<16x32x64xf32>)
  outs(%arg1 : tensor<32x64x16xf32>)
  permutation = [1, 2, 0]
```

Differential Revision: https://reviews.llvm.org/D136818
This commit is contained in:
Alexander Belyaev 2022-10-27 09:39:52 +02:00
parent cfaf3292df
commit 350d686444
5 changed files with 66 additions and 12 deletions

View File

@ -334,6 +334,12 @@ public:
/// operation.
virtual void printNewline() = 0;
/// Increase indentation.
virtual void increaseIndent() = 0;
/// Decrease indentation.
virtual void decreaseIndent() = 0;
/// Print a block argument in the usual format of:
/// %ssaName : type {attr1=42} loc("here")
/// where location printing is controlled by the standard internal option.

View File

@ -173,6 +173,16 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}
static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p,
ValueRange inputs,
ValueRange outputs) {
p.printNewline();
if (!inputs.empty())
p << "ins(" << inputs << " : " << inputs.getTypes() << ")";
p.printNewline();
if (!outputs.empty())
p << "outs(" << outputs << " : " << outputs.getTypes() << ")";
}
//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//
@ -1335,16 +1345,20 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
}
void MapOp::print(OpAsmPrinter &p) {
printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
SmallVector<Value>(getOutputOperands()));
p.increaseIndent();
printCommonStructuredOpPartsWithNewLine(
p, SmallVector<Value>(getInputOperands()),
SmallVector<Value>(getOutputOperands()));
p.printOptionalAttrDict((*this)->getAttrs());
p.printNewline();
p << "(";
llvm::interleaveComma(getMapper().getArguments(), p,
[&](auto arg) { p.printRegionArgument(arg); });
p << ") ";
p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
p.decreaseIndent();
}
LogicalResult MapOp::verify() {
@ -1481,21 +1495,26 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
ArrayRef<int64_t> attributeValue) {
p << " " << attributeName << " = [" << attributeValue << "] ";
p << attributeName << " = [" << attributeValue << "] ";
}
void ReduceOp::print(OpAsmPrinter &p) {
printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
SmallVector<Value>(getOutputOperands()));
p.increaseIndent();
printCommonStructuredOpPartsWithNewLine(
p, SmallVector<Value>(getInputOperands()),
SmallVector<Value>(getOutputOperands()));
p.printNewline();
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
p.printNewline();
p << "(";
llvm::interleaveComma(getCombiner().getArguments(), p,
[&](auto arg) { p.printRegionArgument(arg); });
p << ") ";
p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
p.decreaseIndent();
}
LogicalResult ReduceOp::verify() {
@ -1657,10 +1676,14 @@ void TransposeOp::getAsmResultNames(
}
void TransposeOp::print(OpAsmPrinter &p) {
printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
SmallVector<Value>(getOutputOperands()));
p.increaseIndent();
printCommonStructuredOpPartsWithNewLine(
p, SmallVector<Value>(getInputOperands()),
SmallVector<Value>(getOutputOperands()));
p.printNewline();
printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
p.decreaseIndent();
}
LogicalResult TransposeOp::verify() {

View File

@ -716,6 +716,8 @@ private:
void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
void printNewline() override {}
void increaseIndent() override {}
void decreaseIndent() override {}
void printOperand(Value) override {}
void printOperand(Value, raw_ostream &os) override {
// Users expect the output string to have at least the prefixed % to signal
@ -2768,6 +2770,12 @@ public:
os.indent(currentIndent);
}
/// Increase indentation.
void increaseIndent() override { currentIndent += indentWidth; }
/// Decrease indentation.
void decreaseIndent() override { currentIndent -= indentWidth; }
/// Print a block argument in the usual format of:
/// %ssaName : type {attr1=42} loc("here")
/// where location printing is controlled by the standard internal option.

View File

@ -341,7 +341,7 @@ func.func @op_is_reading_but_following_ops_are_not(
func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
%init: tensor<64xf32>) -> tensor<64xf32> {
// CHECK: linalg.map
// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<64xf32
// CHECK-NEXT: ins(%[[LHS]], %[[RHS]] : memref<64xf32
%add = linalg.map
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
outs(%init:tensor<64xf32>)
@ -359,7 +359,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
func.func @reduce(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
// CHECK: linalg.reduce
// CHECK-SAME: ins(%[[INPUT]] : memref<16x32x64xf32
// CHECK-NEXT: ins(%[[INPUT]] : memref<16x32x64xf32
%reduce = linalg.reduce
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
@ -378,7 +378,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
func.func @transpose(%input: tensor<16x32x64xf32>,
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
// CHECK: linalg.transpose
// CHECK-SAME: ins(%[[ARG0]] : memref<16x32x64xf32
// CHECK-NEXT: ins(%[[ARG0]] : memref<16x32x64xf32
%transpose = linalg.transpose
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<32x64x16xf32>)

View File

@ -338,7 +338,13 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
func.return %add : tensor<64xf32>
}
// CHECK-LABEL: func @map_binary
// CHECK: linalg.map
// CHECK: linalg.map
// CHECK-NEXT: ins
// CHECK-NEXT: outs
// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) {
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }
// -----
@ -401,7 +407,14 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
func.return %reduce : tensor<16x64xf32>
}
// CHECK-LABEL: func @reduce
// CHECK: linalg.reduce
// CHECK: linalg.reduce
// CHECK-NEXT: ins
// CHECK-NEXT: outs
// CHECK-NEXT: dimensions = [1]
// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) {
// CHECK-NEXT: arith.addf
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: }
// -----
@ -469,6 +482,10 @@ func.func @transpose(%input: tensor<16x32x64xf32>,
func.return %transpose : tensor<32x64x16xf32>
}
// CHECK-LABEL: func @transpose
// CHECK: linalg.transpose
// CHECK-NEXT: ins
// CHECK-NEXT: outs
// CHECK-NEXT: permutation
// -----