[mlir][linalg][bufferize] Add FuncOp bufferization pass
This pass bufferizes FuncOp bodies, but not FuncOp boundaries. Differential Revision: https://reviews.llvm.org/D114671
This commit is contained in:
parent
e7f53ec78f
commit
8a232632c5
|
@ -0,0 +1,69 @@
|
||||||
|
// RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
// Run fuzzer with different seeds.
|
||||||
|
// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
|
||||||
|
// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
|
||||||
|
// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
|
||||||
|
|
||||||
|
// Test: a vector is read from a tensor function argument. The expected output
// keeps the argument as a tensor in the signature, converts it with
// bufferization.to_memref, and performs the read on the resulting memref.
// CHECK-LABEL: func @use_tensor_func_arg(
|
||||||
|
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
|
||||||
|
func @use_tensor_func_arg(%A : tensor<?xf32>) -> (vector<4xf32>) {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%f0 = arith.constant 0.0 : f32
|
||||||
|
|
||||||
|
// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
|
||||||
|
// CHECK: %[[res:.*]] = vector.transfer_read %[[A_memref]]
|
||||||
|
%0 = vector.transfer_read %A[%c0], %f0 : tensor<?xf32>, vector<4xf32>
|
||||||
|
|
||||||
|
// CHECK: return %[[res]]
|
||||||
|
return %0 : vector<4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test: a vector is written into a tensor function argument and the result is
// returned. The expected output allocates a new buffer sized by tensor.dim of
// the argument, copies the argument's memref into it, writes into the
// allocation, and returns the buffer wrapped by bufferization.to_tensor.
// CHECK-LABEL: func @return_tensor(
|
||||||
|
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
|
||||||
|
func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
|
||||||
|
// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
|
||||||
|
// CHECK: %[[dim:.*]] = tensor.dim %[[A]]
|
||||||
|
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
|
||||||
|
// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
|
||||||
|
// CHECK: memref.copy %[[A_memref]], %[[casted]]
|
||||||
|
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
|
||||||
|
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
|
||||||
|
|
||||||
|
// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
|
||||||
|
// CHECK: return %[[res_tensor]]
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test: the function has no tensor arguments or results. The init_tensor is
// expected to become a memref.alloc() that both the transfer_write and the
// following transfer_read operate on.
// CHECK-LABEL: func @func_without_tensor_args
|
||||||
|
func @func_without_tensor_args(%v : vector<10xf32>) -> () {
|
||||||
|
// CHECK: %[[alloc:.*]] = memref.alloc()
|
||||||
|
%0 = linalg.init_tensor[10] : tensor<10xf32>
|
||||||
|
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
|
||||||
|
%1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>
|
||||||
|
|
||||||
|
%cst = arith.constant 0.0 : f32
|
||||||
|
// CHECK: vector.transfer_read %[[alloc]]
|
||||||
|
%r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>
|
||||||
|
|
||||||
|
vector.print %r : vector<11xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test: a private function declaration without a body that takes a tensor.
// Only the presence of its label in the output is verified.
// CHECK-LABEL: func private @private_func
|
||||||
|
func private @private_func(tensor<?xf32>) -> ()
|
||||||
|
|
||||||
|
// Test: a function whose body contains only a return. Only the presence of its
// label in the output is verified.
// CHECK-LABEL: func @empty_func()
|
||||||
|
func @empty_func() -> () {
|
||||||
|
return
|
||||||
|
}
|
|
@ -979,3 +979,32 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
|
||||||
}
|
}
|
||||||
return %1: tensor<?xf32>
|
return %1: tensor<?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test: the function has no tensor arguments or results. The init_tensor is
// expected to become a memref.alloc() that both the transfer_write and the
// following transfer_read operate on.
// CHECK-LABEL: func @func_without_tensor_args
|
||||||
|
func @func_without_tensor_args(%v : vector<10xf32>) -> () {
|
||||||
|
// CHECK: %[[alloc:.*]] = memref.alloc()
|
||||||
|
%0 = linalg.init_tensor[10] : tensor<10xf32>
|
||||||
|
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
|
||||||
|
%1 = vector.transfer_write %v, %0[%c0] : vector<10xf32>, tensor<10xf32>
|
||||||
|
|
||||||
|
%cst = arith.constant 0.0 : f32
|
||||||
|
// CHECK: vector.transfer_read %[[alloc]]
|
||||||
|
%r = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<11xf32>
|
||||||
|
|
||||||
|
vector.print %r : vector<11xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test: a private function declaration without a body that takes a tensor.
// Only the presence of its label in the output is verified.
// CHECK-LABEL: func private @private_func
|
||||||
|
func private @private_func(tensor<?xf32>) -> ()
|
||||||
|
|
||||||
|
// Test: a function whose body contains only a return. Only the presence of its
// label in the output is verified.
// CHECK-LABEL: func @empty_func()
|
||||||
|
func @empty_func() -> () {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# Exclude tests from libMLIR.so
|
# Exclude tests from libMLIR.so
|
||||||
add_mlir_library(MLIRLinalgTestPasses
|
add_mlir_library(MLIRLinalgTestPasses
|
||||||
|
TestComprehensiveBufferize.cpp
|
||||||
TestConvVectorization.cpp
|
TestConvVectorization.cpp
|
||||||
TestLinalgCodegenStrategy.cpp
|
TestLinalgCodegenStrategy.cpp
|
||||||
TestLinalgDistribution.cpp
|
TestLinalgDistribution.cpp
|
||||||
|
@ -12,13 +13,25 @@ add_mlir_library(MLIRLinalgTestPasses
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRAffine
|
MLIRAffine
|
||||||
|
MLIRAffineBufferizableOpInterfaceImpl
|
||||||
|
MLIRArithBufferizableOpInterfaceImpl
|
||||||
|
MLIRArithmetic
|
||||||
|
MLIRBufferizableOpInterface
|
||||||
|
MLIRComprehensiveBufferize
|
||||||
MLIRGPUTransforms
|
MLIRGPUTransforms
|
||||||
MLIRLinalg
|
MLIRLinalg
|
||||||
|
MLIRLinalgBufferizableOpInterfaceImpl
|
||||||
MLIRLinalgTransforms
|
MLIRLinalgTransforms
|
||||||
MLIRLLVMToLLVMIRTranslation
|
MLIRLLVMToLLVMIRTranslation
|
||||||
|
MLIRMemRef
|
||||||
MLIRPass
|
MLIRPass
|
||||||
|
MLIRSCF
|
||||||
|
MLIRSCFBufferizableOpInterfaceImpl
|
||||||
MLIRStandard
|
MLIRStandard
|
||||||
|
MLIRTensor
|
||||||
|
MLIRTensorBufferizableOpInterfaceImpl
|
||||||
MLIRTransformUtils
|
MLIRTransformUtils
|
||||||
MLIRVector
|
MLIRVector
|
||||||
|
MLIRVectorBufferizableOpInterfaceImpl
|
||||||
MLIRVectorToSCF
|
MLIRVectorToSCF
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,124 @@
|
||||||
|
//===- TestComprehensiveBufferize.cpp - Test Comprehensive Bufferize ------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements logic for testing Comprehensive Bufferize.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
||||||
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
|
||||||
|
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
|
#include "mlir/Pass/PassManager.h"
|
||||||
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::linalg;
|
||||||
|
using namespace mlir::linalg::comprehensive_bufferize;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// A helper struct for FunctionBufferize and ModuleBufferize. Both passes are
|
||||||
|
/// mostly identical.
|
||||||
|
struct TestComprehensiveFunctionBufferize
|
||||||
|
: public PassWrapper<TestComprehensiveFunctionBufferize, FunctionPass> {
|
||||||
|
StringRef getArgument() const final {
|
||||||
|
return "test-comprehensive-function-bufferize";
|
||||||
|
}
|
||||||
|
|
||||||
|
StringRef getDescription() const final {
|
||||||
|
return "Test Comprehensive Bufferize of FuncOps (body only).";
|
||||||
|
}
|
||||||
|
|
||||||
|
TestComprehensiveFunctionBufferize() = default;
|
||||||
|
TestComprehensiveFunctionBufferize(
|
||||||
|
const TestComprehensiveFunctionBufferize &pass) {}
|
||||||
|
|
||||||
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
|
registry.insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
|
||||||
|
memref::MemRefDialect, tensor::TensorDialect,
|
||||||
|
vector::VectorDialect, scf::SCFDialect,
|
||||||
|
arith::ArithmeticDialect, AffineDialect>();
|
||||||
|
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
|
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
|
bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
|
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
|
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
|
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
|
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnFunction() override;
|
||||||
|
|
||||||
|
Option<bool> allowReturnMemref{
|
||||||
|
*this, "allow-return-memref",
|
||||||
|
llvm::cl::desc("Allow returning/yielding memrefs from functions/blocks"),
|
||||||
|
llvm::cl::init(false)};
|
||||||
|
Option<bool> allowUnknownOps{
|
||||||
|
*this, "allow-unknown-ops",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"Allows the return of memrefs (for testing purposes only)"),
|
||||||
|
llvm::cl::init(false)};
|
||||||
|
Option<bool> testAnalysisOnly{
|
||||||
|
*this, "test-analysis-only",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"Only runs inplaceability analysis (for testing purposes only)"),
|
||||||
|
llvm::cl::init(false)};
|
||||||
|
Option<unsigned> analysisFuzzerSeed{
|
||||||
|
*this, "analysis-fuzzer-seed",
|
||||||
|
llvm::cl::desc("Analyze ops in random order with a given seed (fuzzer)"),
|
||||||
|
llvm::cl::init(0)};
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void TestComprehensiveFunctionBufferize::runOnFunction() {
|
||||||
|
BufferizationOptions options;
|
||||||
|
|
||||||
|
// Enable InitTensorOp elimination.
|
||||||
|
options.addPostAnalysisStep<
|
||||||
|
linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
|
||||||
|
// TODO: Find a way to enable this step automatically when bufferizing
|
||||||
|
// tensor dialect ops.
|
||||||
|
options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
|
||||||
|
options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
|
||||||
|
|
||||||
|
options.allowReturnMemref = allowReturnMemref;
|
||||||
|
options.allowUnknownOps = allowUnknownOps;
|
||||||
|
options.testAnalysisOnly = testAnalysisOnly;
|
||||||
|
options.analysisFuzzerSeed = analysisFuzzerSeed;
|
||||||
|
|
||||||
|
Operation *op = getFunction().getOperation();
|
||||||
|
if (failed(runComprehensiveBufferize(op, options)))
|
||||||
|
return;
|
||||||
|
|
||||||
|
OpPassManager cleanupPipeline("builtin.func");
|
||||||
|
cleanupPipeline.addPass(createCanonicalizerPass());
|
||||||
|
cleanupPipeline.addPass(createCSEPass());
|
||||||
|
cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
|
||||||
|
(void)this->runPipeline(cleanupPipeline, op);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace test {
|
||||||
|
void registerTestComprehensiveFunctionBufferize() {
|
||||||
|
PassRegistration<TestComprehensiveFunctionBufferize>();
|
||||||
|
}
|
||||||
|
} // namespace test
|
||||||
|
} // namespace mlir
|
|
@ -64,6 +64,7 @@ void registerTestAffineLoopParametricTilingPass();
|
||||||
void registerTestAliasAnalysisPass();
|
void registerTestAliasAnalysisPass();
|
||||||
void registerTestBuiltinAttributeInterfaces();
|
void registerTestBuiltinAttributeInterfaces();
|
||||||
void registerTestCallGraphPass();
|
void registerTestCallGraphPass();
|
||||||
|
void registerTestComprehensiveFunctionBufferize();
|
||||||
void registerTestConstantFold();
|
void registerTestConstantFold();
|
||||||
void registerTestConvVectorization();
|
void registerTestConvVectorization();
|
||||||
void registerTestGpuSerializeToCubinPass();
|
void registerTestGpuSerializeToCubinPass();
|
||||||
|
@ -159,6 +160,7 @@ void registerTestPasses() {
|
||||||
#if MLIR_ROCM_CONVERSIONS_ENABLED
|
#if MLIR_ROCM_CONVERSIONS_ENABLED
|
||||||
mlir::test::registerTestGpuSerializeToHsacoPass();
|
mlir::test::registerTestGpuSerializeToHsacoPass();
|
||||||
#endif
|
#endif
|
||||||
|
mlir::test::registerTestComprehensiveFunctionBufferize();
|
||||||
mlir::test::registerTestConvVectorization();
|
mlir::test::registerTestConvVectorization();
|
||||||
mlir::test::registerTestDecomposeCallGraphTypes();
|
mlir::test::registerTestDecomposeCallGraphTypes();
|
||||||
mlir::test::registerTestDataLayoutQuery();
|
mlir::test::registerTestDataLayoutQuery();
|
||||||
|
|
|
@ -381,15 +381,27 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//llvm:Support",
|
"//llvm:Support",
|
||||||
"//mlir:Affine",
|
"//mlir:Affine",
|
||||||
|
"//mlir:AffineBufferizableOpInterfaceImpl",
|
||||||
|
"//mlir:ArithBufferizableOpInterfaceImpl",
|
||||||
"//mlir:ArithmeticDialect",
|
"//mlir:ArithmeticDialect",
|
||||||
|
"//mlir:BufferizableOpInterface",
|
||||||
|
"//mlir:BufferizationDialect",
|
||||||
|
"//mlir:ComprehensiveBufferize",
|
||||||
"//mlir:GPUDialect",
|
"//mlir:GPUDialect",
|
||||||
"//mlir:IR",
|
"//mlir:IR",
|
||||||
|
"//mlir:LinalgBufferizableOpInterfaceImpl",
|
||||||
"//mlir:LinalgOps",
|
"//mlir:LinalgOps",
|
||||||
"//mlir:LinalgTransforms",
|
"//mlir:LinalgTransforms",
|
||||||
|
"//mlir:MemRefDialect",
|
||||||
"//mlir:Pass",
|
"//mlir:Pass",
|
||||||
|
"//mlir:SCFBufferizableOpInterfaceImpl",
|
||||||
|
"//mlir:SCFDialect",
|
||||||
"//mlir:SCFTransforms",
|
"//mlir:SCFTransforms",
|
||||||
"//mlir:StandardOps",
|
"//mlir:StandardOps",
|
||||||
|
"//mlir:TensorBufferizableOpInterfaceImpl",
|
||||||
|
"//mlir:TensorDialect",
|
||||||
"//mlir:TransformUtils",
|
"//mlir:TransformUtils",
|
||||||
|
"//mlir:VectorBufferizableOpInterfaceImpl",
|
||||||
"//mlir:VectorOps",
|
"//mlir:VectorOps",
|
||||||
"//mlir:VectorToSCF",
|
"//mlir:VectorToSCF",
|
||||||
],
|
],
|
||||||
|
|
Loading…
Reference in New Issue