diff --git a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp index 5b2b31db2949..80f538dfb3bc 100644 --- a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp @@ -8,6 +8,7 @@ #include "mlir/Analysis/BufferViewFlowAnalysis.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/SetOperations.h" @@ -51,9 +52,9 @@ void BufferViewFlowAnalysis::remove(const SmallPtrSetImpl &aliasValues) { /// successor regions and branch-like return operations from nested regions. void BufferViewFlowAnalysis::build(Operation *op) { // Registers all dependencies of the given values. - auto registerDependencies = [&](auto values, auto dependencies) { - for (auto entry : llvm::zip(values, dependencies)) - this->dependencies[std::get<0>(entry)].insert(std::get<1>(entry)); + auto registerDependencies = [&](ValueRange values, ValueRange dependencies) { + for (auto [value, dep] : llvm::zip(values, dependencies)) + this->dependencies[value].insert(dep); }; // Add additional dependencies created by view changes to the alias list. @@ -119,4 +120,10 @@ void BufferViewFlowAnalysis::build(Operation *op) { } } }); + + // TODO: This should be an interface. + op->walk([&](arith::SelectOp selectOp) { + registerDependencies({selectOp.getOperand(1)}, {selectOp.getResult()}); + registerDependencies({selectOp.getOperand(2)}, {selectOp.getResult()}); + }); } diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt index 4ead4c712790..701584c53f36 100644 --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -40,6 +40,7 @@ add_mlir_library(MLIRAnalysis mlir-headers LINK_LIBS PUBLIC + MLIRArithmeticDialect MLIRCallInterfaces MLIRControlFlowInterfaces MLIRDataLayoutInterfaces diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir index 240cc2a60681..61493d964093 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir @@ -1298,3 +1298,19 @@ func.func @while_three_arg(%arg0: index) { // CHECK-NEXT: return return } + +// ----- + +func.func @select_aliases(%arg0: index, %arg1: memref, %arg2: i1) { + // CHECK: memref.alloc + // CHECK: memref.alloc + // CHECK: arith.select + // CHECK: test.copy + // CHECK: memref.dealloc + // CHECK: memref.dealloc + %0 = memref.alloc(%arg0) : memref + %1 = memref.alloc(%arg0) : memref + %2 = arith.select %arg2, %0, %1 : memref + test.copy(%2, %arg1) : (memref, memref) + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 7d45b2ef5cc4..cefa2a93afc1 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5879,6 +5879,7 @@ cc_library( ), includes = ["include"], deps = [ + ":ArithmeticDialect", ":CallOpInterfaces", ":ControlFlowInterfaces", ":DataLayoutInterfaces",