[mlir][Tensor] Add folding for tensor.from_elements

This trivially folds into a constant when all operands are constant.

Differential Revision: https://reviews.llvm.org/D102199
This commit is contained in:
Benjamin Kramer 2021-05-10 23:19:59 +02:00
parent 79be9c59c6
commit 7b52aeadfa
4 changed files with 20 additions and 3 deletions

View File

@ -137,6 +137,7 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
];
let hasCanonicalizer = 1;
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -238,6 +238,12 @@ void FromElementsOp::build(OpBuilder &builder, OperationState &result,
build(builder, result, elements.front().getType(), elements);
}
OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
if (!llvm::is_contained(operands, nullptr))
return DenseElementsAttr::get(getType(), operands);
return {};
}
namespace {
// Canonicalizes the pattern of the form

View File

@ -35,9 +35,7 @@ func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
// DET-ALL-NEXT: }
// DET-CF-LABEL: func @main(%{{.*}}: tensor<i32>)
// DET-CF-NEXT: constant 10 : i32
// DET-CF-NEXT: tensor.from_elements %{{.*}}
// DET-CF-NEXT: linalg.tensor_reshape %{{.*}}
// DET-CF-NEXT: constant dense<10> : tensor<i32>
// DET-CF-NEXT: linalg.init_tensor [] : tensor<i1>
// DET-CF-NEXT: linalg.generic
// DET-CF-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i1)

View File

@ -238,3 +238,15 @@ func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xi
// CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
return %0 : tensor<3x?x?x7x?xindex>
}
// -----
// CHECK-LABEL: @from_elements.constant
func @from_elements.constant() -> tensor<3xindex> {
// CHECK: %[[CST:.*]] = constant dense<[1, 2, 1]> : tensor<3xindex>
// CHECK: return %[[CST]]
%c1 = constant 1 : index
%c2 = constant 2 : index
%tensor = tensor.from_elements %c1, %c2, %c1 : tensor<3xindex>
return %tensor : tensor<3xindex>
}