[mlir] Print bbArgs of linalg.map/reduce/transpose on the next line.
```mlir
%mapped = linalg.map
    ins(%arg0 : tensor<64xf32>)
    outs(%arg1 : tensor<64xf32>)
    (%in: f32) {
      %0 = math.absf %in : f32
      linalg.yield %0 : f32
    }

%reduced = linalg.reduce
    ins(%arg0 : tensor<16x32x64xf32>)
    outs(%arg1 : tensor<16x64xf32>)
    dimensions = [1]
    (%in: f32, %init: f32) {
      %0 = arith.addf %in, %init : f32
      linalg.yield %0 : f32
    }

%transposed = linalg.transpose
    ins(%arg0 : tensor<16x32x64xf32>)
    outs(%arg1 : tensor<32x64x16xf32>)
    permutation = [1, 2, 0]
```

Differential Revision: https://reviews.llvm.org/D136818
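For comparison, the previous printers emitted `ins`/`outs`, the trailing attributes, and the block arguments on the same line as the op name (see the removed `printCommonStructuredOpParts` calls below); the old `linalg.map` output looked roughly like this:

```mlir
%mapped = linalg.map ins(%arg0 : tensor<64xf32>) outs(%arg1 : tensor<64xf32>) (%in: f32) {
  %0 = math.absf %in : f32
  linalg.yield %0 : f32
}
```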
```diff
@@ -334,6 +334,12 @@ public:
   /// operation.
   virtual void printNewline() = 0;
 
+  /// Increase indentation.
+  virtual void increaseIndent() = 0;
+
+  /// Decrease indentation.
+  virtual void decreaseIndent() = 0;
+
   /// Print a block argument in the usual format of:
   /// %ssaName : type {attr1=42} loc("here")
   /// where location printing is controlled by the standard internal option.
```
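`printNewline()` already existed; the two new hooks let an op printer temporarily raise the indentation level that `printNewline()` applies. A minimal sketch of a custom printer built on them (the op and its accessors are hypothetical, used only to illustrate the calling pattern the Linalg printers below follow):

```cpp
// Hypothetical op printer; getInputs()/getBody() are assumed accessors.
void MyWrappedOp::print(OpAsmPrinter &p) {
  p.increaseIndent();   // fragments printed below land one level deeper
  p.printNewline();     // line break followed by the current indentation
  p << "ins(" << getInputs() << " : " << getInputs().getTypes() << ")";
  p.printOptionalAttrDict((*this)->getAttrs());
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
  p.decreaseIndent();   // restore the caller's indentation level
}
```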
```diff
@@ -173,6 +173,16 @@ static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
     p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
 }
 
+static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p,
+                                                    ValueRange inputs,
+                                                    ValueRange outputs) {
+  p.printNewline();
+  if (!inputs.empty())
+    p << "ins(" << inputs << " : " << inputs.getTypes() << ")";
+  p.printNewline();
+  if (!outputs.empty())
+    p << "outs(" << outputs << " : " << outputs.getTypes() << ")";
+}
 //===----------------------------------------------------------------------===//
 // Specific parsing and printing for named structured ops created by ods-gen.
 //===----------------------------------------------------------------------===//
@@ -1335,16 +1345,20 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void MapOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
-                               SmallVector<Value>(getOutputOperands()));
+  p.increaseIndent();
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getInputOperands()),
+      SmallVector<Value>(getOutputOperands()));
   p.printOptionalAttrDict((*this)->getAttrs());
 
+  p.printNewline();
   p << "(";
   llvm::interleaveComma(getMapper().getArguments(), p,
                         [&](auto arg) { p.printRegionArgument(arg); });
   p << ") ";
 
   p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+  p.decreaseIndent();
 }
 
 LogicalResult MapOp::verify() {
@@ -1481,21 +1495,26 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
 
 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                    ArrayRef<int64_t> attributeValue) {
-  p << " " << attributeName << " = [" << attributeValue << "] ";
+  p << attributeName << " = [" << attributeValue << "] ";
 }
 
 void ReduceOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
-                               SmallVector<Value>(getOutputOperands()));
+  p.increaseIndent();
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getInputOperands()),
+      SmallVector<Value>(getOutputOperands()));
+  p.printNewline();
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
 
+  p.printNewline();
   p << "(";
   llvm::interleaveComma(getCombiner().getArguments(), p,
                         [&](auto arg) { p.printRegionArgument(arg); });
   p << ") ";
 
   p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
+  p.decreaseIndent();
 }
 
 LogicalResult ReduceOp::verify() {
@@ -1657,10 +1676,14 @@ void TransposeOp::getAsmResultNames(
 }
 
 void TransposeOp::print(OpAsmPrinter &p) {
-  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
-                               SmallVector<Value>(getOutputOperands()));
+  p.increaseIndent();
+  printCommonStructuredOpPartsWithNewLine(
+      p, SmallVector<Value>(getInputOperands()),
+      SmallVector<Value>(getOutputOperands()));
+  p.printNewline();
   printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
   p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
+  p.decreaseIndent();
 }
 
 LogicalResult TransposeOp::verify() {
```
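Tracing the new `MapOp::print` above, each line of the output comes from one of these calls; a sketch of the printed result with the originating call noted (indentation width depends on the printer):

```mlir
%mapped = linalg.map              // results and op name, printed before MapOp::print runs
    ins(%arg0 : tensor<64xf32>)   // printCommonStructuredOpPartsWithNewLine, first printNewline()
    outs(%arg1 : tensor<64xf32>)  // same helper, second printNewline()
    (%in: f32) {                  // printNewline(), then the printRegionArgument loop and printRegion
      %0 = math.absf %in : f32
      linalg.yield %0 : f32
    }
```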
```diff
@@ -716,6 +716,8 @@ private:
   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
   void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
   void printNewline() override {}
+  void increaseIndent() override {}
+  void decreaseIndent() override {}
   void printOperand(Value) override {}
   void printOperand(Value, raw_ostream &os) override {
     // Users expect the output string to have at least the prefixed % to signal
@@ -2768,6 +2770,12 @@ public:
     os.indent(currentIndent);
   }
 
+  /// Increase indentation.
+  void increaseIndent() override { currentIndent += indentWidth; }
+
+  /// Decrease indentation.
+  void decreaseIndent() override { currentIndent -= indentWidth; }
+
   /// Print a block argument in the usual format of:
   /// %ssaName : type {attr1=42} loc("here")
   /// where location printing is controlled by the standard internal option.
```
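One of the `OpAsmPrinter` implementations above stubs the new hooks out as no-ops, while the one that actually emits text keeps a `currentIndent` counter that `printNewline()` applies via `os.indent(currentIndent)` (visible in the context line). A small sketch of the assumed contract (the helper name is made up):

```cpp
// Sketch: indent changes take effect at the next printNewline(); the
// increase/decrease calls are expected to stay balanced.
static void printTwoNestedLines(mlir::OpAsmPrinter &p) {
  p.increaseIndent();   // currentIndent += indentWidth
  p.printNewline();     // '\n' followed by currentIndent spaces
  p << "outer";
  p.increaseIndent();
  p.printNewline();     // now two levels deep
  p << "inner";
  p.decreaseIndent();
  p.decreaseIndent();   // back to the caller's level
}
```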
```diff
@@ -341,7 +341,7 @@ func.func @op_is_reading_but_following_ops_are_not(
 func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
                       %init: tensor<64xf32>) -> tensor<64xf32> {
   // CHECK: linalg.map
-  // CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<64xf32
+  // CHECK-NEXT: ins(%[[LHS]], %[[RHS]] : memref<64xf32
   %add = linalg.map
         ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
         outs(%init:tensor<64xf32>)
@@ -359,7 +359,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 func.func @reduce(%input: tensor<16x32x64xf32>,
                   %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
   // CHECK: linalg.reduce
-  // CHECK-SAME: ins(%[[INPUT]] : memref<16x32x64xf32
+  // CHECK-NEXT: ins(%[[INPUT]] : memref<16x32x64xf32
   %reduce = linalg.reduce
         ins(%input:tensor<16x32x64xf32>)
         outs(%init:tensor<16x64xf32>)
@@ -378,7 +378,7 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
 func.func @transpose(%input: tensor<16x32x64xf32>,
                      %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
   // CHECK: linalg.transpose
-  // CHECK-SAME: ins(%[[ARG0]] : memref<16x32x64xf32
+  // CHECK-NEXT: ins(%[[ARG0]] : memref<16x32x64xf32
   %transpose = linalg.transpose
         ins(%input:tensor<16x32x64xf32>)
         outs(%init:tensor<32x64x16xf32>)
```
```diff
@@ -338,7 +338,13 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
   func.return %add : tensor<64xf32>
 }
 // CHECK-LABEL: func @map_binary
 // CHECK: linalg.map
+// CHECK-NEXT: ins
+// CHECK-NEXT: outs
+// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) {
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: }
 
 // -----
 
@@ -401,7 +407,14 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
   func.return %reduce : tensor<16x64xf32>
 }
 // CHECK-LABEL: func @reduce
 // CHECK: linalg.reduce
+// CHECK-NEXT: ins
+// CHECK-NEXT: outs
+// CHECK-NEXT: dimensions = [1]
+// CHECK-NEXT: (%{{.*}}: f32, %{{.*}}: f32) {
+// CHECK-NEXT: arith.addf
+// CHECK-NEXT: linalg.yield
+// CHECK-NEXT: }
 
 // -----
 
@@ -469,6 +482,10 @@ func.func @transpose(%input: tensor<16x32x64xf32>,
   func.return %transpose : tensor<32x64x16xf32>
 }
 // CHECK-LABEL: func @transpose
+// CHECK: linalg.transpose
+// CHECK-NEXT: ins
+// CHECK-NEXT: outs
+// CHECK-NEXT: permutation
 
 // -----
 
```
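For reference, the round-tripped output that the new `@transpose` checks match is essentially the transpose example from the commit message, e.g.:

```mlir
%transposed = linalg.transpose
    ins(%arg0 : tensor<16x32x64xf32>)
    outs(%arg1 : tensor<32x64x16xf32>)
    permutation = [1, 2, 0]
```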