[mlir][Tensor] Add folding for tensor.from_elements
This trivially folds into a constant when all operands are constant. Differential Revision: https://reviews.llvm.org/D102199
This commit is contained in:
parent
79be9c59c6
commit
7b52aeadfa
|
@ -137,6 +137,7 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
|
|||
];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -238,6 +238,12 @@ void FromElementsOp::build(OpBuilder &builder, OperationState &result,
|
|||
build(builder, result, elements.front().getType(), elements);
|
||||
}
|
||||
|
||||
OpFoldResult FromElementsOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (!llvm::is_contained(operands, nullptr))
|
||||
return DenseElementsAttr::get(getType(), operands);
|
||||
return {};
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Canonicalizes the pattern of the form
|
||||
|
|
|
@ -35,9 +35,7 @@ func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
|
|||
// DET-ALL-NEXT: }
|
||||
|
||||
// DET-CF-LABEL: func @main(%{{.*}}: tensor<i32>)
|
||||
// DET-CF-NEXT: constant 10 : i32
|
||||
// DET-CF-NEXT: tensor.from_elements %{{.*}}
|
||||
// DET-CF-NEXT: linalg.tensor_reshape %{{.*}}
|
||||
// DET-CF-NEXT: constant dense<10> : tensor<i32>
|
||||
// DET-CF-NEXT: linalg.init_tensor [] : tensor<i1>
|
||||
// DET-CF-NEXT: linalg.generic
|
||||
// DET-CF-NEXT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32, %{{.*}}: i1)
|
||||
|
|
|
@ -238,3 +238,15 @@ func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xi
|
|||
// CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
|
||||
return %0 : tensor<3x?x?x7x?xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @from_elements.constant
|
||||
func @from_elements.constant() -> tensor<3xindex> {
|
||||
// CHECK: %[[CST:.*]] = constant dense<[1, 2, 1]> : tensor<3xindex>
|
||||
// CHECK: return %[[CST]]
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%tensor = tensor.from_elements %c1, %c2, %c1 : tensor<3xindex>
|
||||
return %tensor : tensor<3xindex>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue