[mlir][Pass][NFC] Replace usages of ModulePass with OperationPass<ModuleOp>
ModulePass doesn't provide any special utilities and thus doesn't give enough benefit to warrant a special pass class. This revision replaces all usages with the more general OperationPass. Differential Revision: https://reviews.llvm.org/D77339
This commit is contained in:
parent
2481f26ac3
commit
722f909f7a
|
@ -105,7 +105,7 @@ We want to completely lower to LLVM, so we use a `FullConversion`. This ensures
|
|||
that only legal operations will remain after the conversion.
|
||||
|
||||
```c++
|
||||
mlir::ModuleOp module = getModule();
|
||||
mlir::ModuleOp module = getOperation();
|
||||
if (mlir::failed(mlir::applyFullConversion(module, target, patterns,
|
||||
&typeConverter)))
|
||||
signalPassFailure();
|
||||
|
|
|
@ -153,12 +153,13 @@ private:
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
|
||||
void runOnModule() final;
|
||||
struct ToyToLLVMLoweringPass
|
||||
: public OperationPass<ToyToLLVMLoweringPass, ModuleOp> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void ToyToLLVMLoweringPass::runOnModule() {
|
||||
void ToyToLLVMLoweringPass::runOnOperation() {
|
||||
// The first thing to define is the conversion target. This will define the
|
||||
// final target for this lowering. For this lowering, we are only targeting
|
||||
// the LLVM dialect.
|
||||
|
@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() {
|
|||
|
||||
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
||||
// ensures that only legal operations will remain after the conversion.
|
||||
auto module = getModule();
|
||||
auto module = getOperation();
|
||||
if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -153,12 +153,13 @@ private:
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
|
||||
void runOnModule() final;
|
||||
struct ToyToLLVMLoweringPass
|
||||
: public OperationPass<ToyToLLVMLoweringPass, ModuleOp> {
|
||||
void runOnOperation() final;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void ToyToLLVMLoweringPass::runOnModule() {
|
||||
void ToyToLLVMLoweringPass::runOnOperation() {
|
||||
// The first thing to define is the conversion target. This will define the
|
||||
// final target for this lowering. For this lowering, we are only targeting
|
||||
// the LLVM dialect.
|
||||
|
@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() {
|
|||
|
||||
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
||||
// ensures that only legal operations will remain after the conversion.
|
||||
auto module = getModule();
|
||||
auto module = getOperation();
|
||||
if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -341,24 +341,9 @@ template <typename T> struct FunctionPass : public OperationPass<T, FuncOp> {
|
|||
runOnFunction();
|
||||
}
|
||||
|
||||
/// Return the current module being transformed.
|
||||
/// Return the current function being transformed.
|
||||
FuncOp getFunction() { return this->getOperation(); }
|
||||
};
|
||||
|
||||
/// A model for providing module pass specific utilities.
|
||||
///
|
||||
/// Derived module passes are expected to provide the following:
|
||||
/// - A 'void runOnModule()' method.
|
||||
template <typename T> struct ModulePass : public OperationPass<T, ModuleOp> {
|
||||
/// The polymorphic API that runs the pass over the currently held module.
|
||||
virtual void runOnModule() = 0;
|
||||
|
||||
/// The polymorphic API that runs the pass over the currently held operation.
|
||||
void runOnOperation() final { runOnModule(); }
|
||||
|
||||
/// Return the current module being transformed.
|
||||
ModuleOp getModule() { return this->getOperation(); }
|
||||
};
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_PASS_PASS_H
|
||||
|
|
|
@ -163,16 +163,17 @@ void mlir::populateAVX512ToLLVMConversionPatterns(
|
|||
}
|
||||
|
||||
namespace {
|
||||
struct ConvertAVX512ToLLVMPass : public ModulePass<ConvertAVX512ToLLVMPass> {
|
||||
struct ConvertAVX512ToLLVMPass
|
||||
: public OperationPass<ConvertAVX512ToLLVMPass, ModuleOp> {
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_ConvertAVX512ToLLVM
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ConvertAVX512ToLLVMPass::runOnModule() {
|
||||
void ConvertAVX512ToLLVMPass::runOnOperation() {
|
||||
// Convert to the LLVM IR dialect.
|
||||
OwningRewritePatternList patterns;
|
||||
LLVMTypeConverter converter(&getContext());
|
||||
|
@ -186,8 +187,8 @@ void ConvertAVX512ToLLVMPass::runOnModule() {
|
|||
target.addIllegalDialect<avx512::AVX512Dialect>();
|
||||
target.addDynamicallyLegalOp<FuncOp>(
|
||||
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
||||
if (failed(
|
||||
applyPartialConversion(getModule(), target, patterns, &converter))) {
|
||||
if (failed(applyPartialConversion(getOperation(), target, patterns,
|
||||
&converter))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -61,7 +61,7 @@ namespace {
|
|||
///
|
||||
/// Intermediate data structures are allocated on the stack.
|
||||
class GpuLaunchFuncToCudaCallsPass
|
||||
: public ModulePass<GpuLaunchFuncToCudaCallsPass> {
|
||||
: public OperationPass<GpuLaunchFuncToCudaCallsPass, ModuleOp> {
|
||||
private:
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_ConvertGpuLaunchFuncToCudaCalls
|
||||
|
@ -126,20 +126,19 @@ private:
|
|||
|
||||
public:
|
||||
// Run the dialect converter on the module.
|
||||
void runOnModule() override {
|
||||
void runOnOperation() override {
|
||||
// Cache the LLVMDialect for the current module.
|
||||
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
// Cache the used LLVM types.
|
||||
initializeCachedTypes();
|
||||
|
||||
getModule().walk([this](mlir::gpu::LaunchFuncOp op) {
|
||||
translateGpuLaunchCalls(op);
|
||||
});
|
||||
getOperation().walk(
|
||||
[this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
|
||||
|
||||
// GPU kernel modules are no longer necessary since we have a global
|
||||
// constant with the CUBIN data.
|
||||
for (auto m :
|
||||
llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
|
||||
llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
|
||||
m.erase();
|
||||
}
|
||||
|
||||
|
@ -160,7 +159,7 @@ private:
|
|||
// The types in comments give the actual types expected/returned but the API
|
||||
// uses void pointers. This is fine as they have the same linkage in C.
|
||||
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
|
||||
ModuleOp module = getModule();
|
||||
ModuleOp module = getOperation();
|
||||
OpBuilder builder(module.getBody()->getTerminator());
|
||||
if (!module.lookupSymbol(cuModuleLoadName)) {
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
|
@ -391,7 +390,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
builder.getI32IntegerAttr(0));
|
||||
// Create an LLVM global with CUBIN extracted from the kernel annotation and
|
||||
// obtain a pointer to the first byte in it.
|
||||
auto kernelModule = getModule().lookupSymbol<gpu::GPUModuleOp>(
|
||||
auto kernelModule = getOperation().lookupSymbol<gpu::GPUModuleOp>(
|
||||
launchOp.getKernelModuleName());
|
||||
assert(kernelModule && "expected a kernel module");
|
||||
|
||||
|
@ -412,7 +411,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
// in the called helper function.
|
||||
auto cuModule = allocatePointer(builder, loc);
|
||||
auto cuModuleLoad =
|
||||
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
|
||||
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getSymbolRefAttr(cuModuleLoad),
|
||||
ArrayRef<Value>{cuModule, data});
|
||||
|
@ -423,20 +422,20 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder);
|
||||
auto cuFunction = allocatePointer(builder, loc);
|
||||
auto cuModuleGetFunction =
|
||||
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
|
||||
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
|
||||
builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getSymbolRefAttr(cuModuleGetFunction),
|
||||
ArrayRef<Value>{cuFunction, cuOwningModuleRef, kernelName});
|
||||
// Grab the global stream needed for execution.
|
||||
auto cuGetStreamHelper =
|
||||
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
|
||||
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
|
||||
auto cuStream = builder.create<LLVM::CallOp>(
|
||||
loc, ArrayRef<Type>{getPointerType()},
|
||||
builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value>{});
|
||||
// Invoke the function with required arguments.
|
||||
auto cuLaunchKernel =
|
||||
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
|
||||
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
|
||||
auto cuFunctionRef =
|
||||
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
|
||||
auto paramsArray = setupParamsArray(launchOp, builder);
|
||||
|
@ -458,7 +457,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
|
|||
nullpointer /* extra */});
|
||||
// Sync on the stream to make it synchronous.
|
||||
auto cuStreamSync =
|
||||
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
|
||||
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
|
||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
|
||||
builder.getSymbolRefAttr(cuStreamSync),
|
||||
ArrayRef<Value>(cuStream.getResult(0)));
|
||||
|
|
|
@ -33,18 +33,18 @@ namespace {
|
|||
/// replace it).
|
||||
///
|
||||
/// 2) Lower the body of the spirv::ModuleOp.
|
||||
struct GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
|
||||
struct GPUToSPIRVPass : public OperationPass<GPUToSPIRVPass, ModuleOp> {
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_ConvertGpuToSPIRV
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void GPUToSPIRVPass::runOnModule() {
|
||||
void GPUToSPIRVPass::runOnOperation() {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp module = getModule();
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
SmallVector<Operation *, 1> kernelModules;
|
||||
OpBuilder builder(context);
|
||||
|
|
|
@ -38,13 +38,13 @@ namespace {
|
|||
/// function and attaching binary data and entry point name as an attributes to
|
||||
/// created vulkan launch call op.
|
||||
class ConvertGpuLaunchFuncToVulkanLaunchFunc
|
||||
: public ModulePass<ConvertGpuLaunchFuncToVulkanLaunchFunc> {
|
||||
: public OperationPass<ConvertGpuLaunchFuncToVulkanLaunchFunc, ModuleOp> {
|
||||
public:
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_ConvertGpuLaunchFuncToVulkanLaunchFunc
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
/// Creates a SPIR-V binary shader from the given `module` using
|
||||
|
@ -68,14 +68,13 @@ private:
|
|||
/// operand is unsupported by Vulkan runtime.
|
||||
LogicalResult declareVulkanLaunchFunc(Location loc,
|
||||
gpu::LaunchFuncOp launchOp);
|
||||
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
|
||||
void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnOperation() {
|
||||
bool done = false;
|
||||
getModule().walk([this, &done](gpu::LaunchFuncOp op) {
|
||||
getOperation().walk([this, &done](gpu::LaunchFuncOp op) {
|
||||
if (done) {
|
||||
op.emitError("should only contain one 'gpu::LaunchFuncOp' op");
|
||||
return signalPassFailure();
|
||||
|
@ -86,17 +85,17 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
|
|||
|
||||
// Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
|
||||
for (auto gpuModule :
|
||||
llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
|
||||
llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
|
||||
gpuModule.erase();
|
||||
|
||||
for (auto spirvModule :
|
||||
llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
|
||||
llvm::make_early_inc_range(getOperation().getOps<spirv::ModuleOp>()))
|
||||
spirvModule.erase();
|
||||
}
|
||||
|
||||
LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
|
||||
Location loc, gpu::LaunchFuncOp launchOp) {
|
||||
OpBuilder builder(getModule().getBody()->getTerminator());
|
||||
OpBuilder builder(getOperation().getBody()->getTerminator());
|
||||
// TODO: Workgroup size is written into the kernel. So to properly modelling
|
||||
// vulkan launch, we cannot have the local workgroup size configuration here.
|
||||
SmallVector<Type, 8> vulkanLaunchTypes{launchOp.getOperandTypes()};
|
||||
|
@ -138,7 +137,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::createBinaryShader(
|
|||
|
||||
void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
|
||||
gpu::LaunchFuncOp launchOp) {
|
||||
ModuleOp module = getModule();
|
||||
ModuleOp module = getOperation();
|
||||
OpBuilder builder(launchOp);
|
||||
Location loc = launchOp.getLoc();
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ namespace {
|
|||
/// * deinitVulkan -- deinitializes vulkan runtime
|
||||
///
|
||||
class VulkanLaunchFuncToVulkanCallsPass
|
||||
: public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
|
||||
: public OperationPass<VulkanLaunchFuncToVulkanCallsPass, ModuleOp> {
|
||||
private:
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls
|
||||
|
@ -150,7 +150,7 @@ private:
|
|||
LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
|
||||
|
||||
public:
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
|
||||
private:
|
||||
LLVM::LLVMDialect *llvmDialect;
|
||||
|
@ -169,18 +169,18 @@ private:
|
|||
|
||||
} // anonymous namespace
|
||||
|
||||
void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
|
||||
void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
|
||||
initializeCachedTypes();
|
||||
|
||||
// Collect SPIR-V attributes such as `spirv_blob` and
|
||||
// `spirv_entry_point_name`.
|
||||
getModule().walk([this](LLVM::CallOp op) {
|
||||
getOperation().walk([this](LLVM::CallOp op) {
|
||||
if (isVulkanLaunchCallOp(op))
|
||||
collectSPIRVAttributes(op);
|
||||
});
|
||||
|
||||
// Convert vulkan launch call op into a sequence of Vulkan runtime calls.
|
||||
getModule().walk([this](LLVM::CallOp op) {
|
||||
getOperation().walk([this](LLVM::CallOp op) {
|
||||
if (isCInterfaceVulkanLaunchCallOp(op))
|
||||
translateVulkanLaunchCall(op);
|
||||
});
|
||||
|
@ -278,7 +278,7 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
|
|||
}
|
||||
|
||||
void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
|
||||
ModuleOp module = getModule();
|
||||
ModuleOp module = getOperation();
|
||||
OpBuilder builder(module.getBody()->getTerminator());
|
||||
|
||||
if (!module.lookupSymbol(kSetEntryPoint)) {
|
||||
|
|
|
@ -561,17 +561,18 @@ void mlir::populateLinalgToLLVMConversionPatterns(
|
|||
}
|
||||
|
||||
namespace {
|
||||
struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> {
|
||||
struct ConvertLinalgToLLVMPass
|
||||
: public OperationPass<ConvertLinalgToLLVMPass, ModuleOp> {
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_ConvertLinalgToLLVM
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ConvertLinalgToLLVMPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
void ConvertLinalgToLLVMPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
|
||||
// Convert to the LLVM IR dialect using the converter defined above.
|
||||
OwningRewritePatternList patterns;
|
||||
|
|
|
@ -16,18 +16,18 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
/// A pass converting MLIR Linalg ops into SPIR-V ops.
|
||||
class LinalgToSPIRVPass : public ModulePass<LinalgToSPIRVPass> {
|
||||
class LinalgToSPIRVPass : public OperationPass<LinalgToSPIRVPass, ModuleOp> {
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_ConvertLinalgToSPIRV
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void LinalgToSPIRVPass::runOnModule() {
|
||||
void LinalgToSPIRVPass::runOnOperation() {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp module = getModule();
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
|
||||
std::unique_ptr<ConversionTarget> target =
|
||||
|
|
|
@ -2847,7 +2847,7 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
|
|||
|
||||
namespace {
|
||||
/// A pass converting MLIR operations into the LLVM IR dialect.
|
||||
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
||||
struct LLVMLoweringPass : public OperationPass<LLVMLoweringPass, ModuleOp> {
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_ConvertStandardToLLVM
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
@ -2863,16 +2863,16 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
|
|||
LLVMLoweringPass(const LLVMLoweringPass &pass) {}
|
||||
|
||||
/// Run the dialect converter on the module.
|
||||
void runOnModule() override {
|
||||
void runOnOperation() override {
|
||||
if (useBarePtrCallConv && emitCWrappers) {
|
||||
getModule().emitError()
|
||||
getOperation().emitError()
|
||||
<< "incompatible conversion options: bare-pointer calling convention "
|
||||
"and C wrapper emission";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
ModuleOp m = getModule();
|
||||
ModuleOp m = getOperation();
|
||||
|
||||
LLVMTypeConverterCustomization customs;
|
||||
customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
|
||||
|
|
|
@ -22,18 +22,18 @@ using namespace mlir;
|
|||
namespace {
|
||||
/// A pass converting MLIR Standard operations into the SPIR-V dialect.
|
||||
class ConvertStandardToSPIRVPass
|
||||
: public ModulePass<ConvertStandardToSPIRVPass> {
|
||||
: public OperationPass<ConvertStandardToSPIRVPass, ModuleOp> {
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_ConvertStandardToSPIRV
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ConvertStandardToSPIRVPass::runOnModule() {
|
||||
void ConvertStandardToSPIRVPass::runOnOperation() {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp module = getModule();
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
|
||||
std::unique_ptr<ConversionTarget> target =
|
||||
|
|
|
@ -1118,23 +1118,24 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns(
|
|||
}
|
||||
|
||||
namespace {
|
||||
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
|
||||
struct LowerVectorToLLVMPass
|
||||
: public OperationPass<LowerVectorToLLVMPass, ModuleOp> {
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_ConvertVectorToLLVM
|
||||
#include "mlir/Conversion/Passes.h.inc"
|
||||
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void LowerVectorToLLVMPass::runOnModule() {
|
||||
void LowerVectorToLLVMPass::runOnOperation() {
|
||||
// Perform progressive lowering of operations on slices and
|
||||
// all contraction operations. Also applies folding and DCE.
|
||||
{
|
||||
OwningRewritePatternList patterns;
|
||||
populateVectorSlicesLoweringPatterns(patterns, &getContext());
|
||||
populateVectorContractLoweringPatterns(patterns, &getContext());
|
||||
applyPatternsGreedily(getModule(), patterns);
|
||||
applyPatternsGreedily(getOperation(), patterns);
|
||||
}
|
||||
|
||||
// Convert to the LLVM IR dialect.
|
||||
|
@ -1148,8 +1149,8 @@ void LowerVectorToLLVMPass::runOnModule() {
|
|||
LLVMConversionTarget target(getContext());
|
||||
target.addDynamicallyLegalOp<FuncOp>(
|
||||
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
|
||||
if (failed(
|
||||
applyPartialConversion(getModule(), target, patterns, &converter))) {
|
||||
if (failed(applyPartialConversion(getOperation(), target, patterns,
|
||||
&converter))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -214,16 +214,17 @@ namespace {
|
|||
/// The gpu.modules are intended to be compiled to a cubin blob independently in
|
||||
/// a separate pass. The external functions can then be annotated with the
|
||||
/// symbol of the cubin accessor function.
|
||||
class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
|
||||
class GpuKernelOutliningPass
|
||||
: public OperationPass<GpuKernelOutliningPass, ModuleOp> {
|
||||
public:
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_GpuKernelOutlining
|
||||
#include "mlir/Dialect/GPU/Passes.h.inc"
|
||||
|
||||
void runOnModule() override {
|
||||
SymbolTable symbolTable(getModule());
|
||||
void runOnOperation() override {
|
||||
SymbolTable symbolTable(getOperation());
|
||||
bool modified = false;
|
||||
for (auto func : getModule().getOps<FuncOp>()) {
|
||||
for (auto func : getOperation().getOps<FuncOp>()) {
|
||||
// Insert just after the function.
|
||||
Block::iterator insertPt(func.getOperation()->getNextNode());
|
||||
auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
|
||||
|
@ -255,8 +256,8 @@ public:
|
|||
// If any new module was inserted in this module, annotate this module as
|
||||
// a container module.
|
||||
if (modified)
|
||||
getModule().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
|
||||
UnitAttr::get(&getContext()));
|
||||
getOperation().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
|
||||
UnitAttr::get(&getContext()));
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -267,7 +268,7 @@ private:
|
|||
// a SymbolTable by the caller. SymbolTable needs to be refactored to
|
||||
// prevent manual building of Ops with symbols in code using SymbolTables
|
||||
// and then this needs to use the OpBuilder.
|
||||
auto context = getModule().getContext();
|
||||
auto context = getOperation().getContext();
|
||||
Builder builder(context);
|
||||
OperationState state(kernelFunc.getLoc(),
|
||||
gpu::GPUModuleOp::getOperationName());
|
||||
|
|
|
@ -80,14 +80,14 @@ static void populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns,
|
|||
|
||||
namespace {
|
||||
class DecorateSPIRVCompositeTypeLayoutPass
|
||||
: public ModulePass<DecorateSPIRVCompositeTypeLayoutPass> {
|
||||
: public OperationPass<DecorateSPIRVCompositeTypeLayoutPass, ModuleOp> {
|
||||
private:
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
|
||||
auto module = getOperation();
|
||||
OwningRewritePatternList patterns;
|
||||
populateSPIRVLayoutInfoPatterns(patterns, module.getContext());
|
||||
ConversionTarget target(*(module.getContext()));
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
|
||||
struct PrintOpStatsPass : public OperationPass<PrintOpStatsPass, ModuleOp> {
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_PrintOpStats
|
||||
#include "mlir/Transforms/Passes.h.inc"
|
||||
|
@ -26,7 +26,7 @@ struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
|
|||
explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {}
|
||||
|
||||
// Prints the resultant operation statistics post iterating over the module.
|
||||
void runOnModule() override;
|
||||
void runOnOperation() override;
|
||||
|
||||
// Print summary of op stats.
|
||||
void printSummary();
|
||||
|
@ -37,11 +37,11 @@ private:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
void PrintOpStatsPass::runOnModule() {
|
||||
void PrintOpStatsPass::runOnOperation() {
|
||||
opCount.clear();
|
||||
|
||||
// Compute the operation statistics for each function in the module.
|
||||
for (auto &op : getModule())
|
||||
for (auto &op : getOperation())
|
||||
op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
|
||||
printSummary();
|
||||
}
|
||||
|
|
|
@ -100,7 +100,7 @@ namespace {
|
|||
// PrintOpPass is simple pass to write graph per function.
|
||||
// Note: this is a module pass only to avoid interleaving on the same ostream
|
||||
// due to multi-threading over functions.
|
||||
struct PrintOpPass : public ModulePass<PrintOpPass> {
|
||||
struct PrintOpPass : public OperationPass<PrintOpPass, ModuleOp> {
|
||||
/// Include the generated pass utilities.
|
||||
#define GEN_PASS_PrintOpGraph
|
||||
#include "mlir/Transforms/Passes.h.inc"
|
||||
|
@ -140,7 +140,7 @@ struct PrintOpPass : public ModulePass<PrintOpPass> {
|
|||
}
|
||||
}
|
||||
|
||||
void runOnModule() override { processModule(getModule()); }
|
||||
void runOnOperation() override { processModule(getOperation()); }
|
||||
|
||||
private:
|
||||
raw_ostream &os;
|
||||
|
|
|
@ -398,13 +398,13 @@ struct TestTypeConverter : public TypeConverter {
|
|||
};
|
||||
|
||||
struct TestLegalizePatternDriver
|
||||
: public ModulePass<TestLegalizePatternDriver> {
|
||||
: public OperationPass<TestLegalizePatternDriver, ModuleOp> {
|
||||
/// The mode of conversion to use with the driver.
|
||||
enum class ConversionMode { Analysis, Full, Partial };
|
||||
|
||||
TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
|
||||
|
||||
void runOnModule() override {
|
||||
void runOnOperation() override {
|
||||
TestTypeConverter converter;
|
||||
mlir::OwningRewritePatternList patterns;
|
||||
populateWithGenerated(&getContext(), &patterns);
|
||||
|
@ -450,7 +450,8 @@ struct TestLegalizePatternDriver
|
|||
|
||||
// Handle a partial conversion.
|
||||
if (mode == ConversionMode::Partial) {
|
||||
(void)applyPartialConversion(getModule(), target, patterns, &converter);
|
||||
(void)applyPartialConversion(getOperation(), target, patterns,
|
||||
&converter);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -461,7 +462,7 @@ struct TestLegalizePatternDriver
|
|||
return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
|
||||
});
|
||||
|
||||
(void)applyFullConversion(getModule(), target, patterns, &converter);
|
||||
(void)applyFullConversion(getOperation(), target, patterns, &converter);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -470,7 +471,7 @@ struct TestLegalizePatternDriver
|
|||
|
||||
// Analyze the convertible operations.
|
||||
DenseSet<Operation *> legalizedOps;
|
||||
if (failed(applyAnalysisConversion(getModule(), target, patterns,
|
||||
if (failed(applyAnalysisConversion(getOperation(), target, patterns,
|
||||
legalizedOps, &converter)))
|
||||
return signalPassFailure();
|
||||
|
||||
|
|
|
@ -13,9 +13,9 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
/// This is a test pass for verifying FuncOp's eraseArgument method.
|
||||
struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
|
||||
void runOnModule() override {
|
||||
auto module = getModule();
|
||||
struct TestFuncEraseArg : public OperationPass<TestFuncEraseArg, ModuleOp> {
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
|
||||
for (FuncOp func : module.getOps<FuncOp>()) {
|
||||
SmallVector<unsigned, 4> indicesToErase;
|
||||
|
@ -36,9 +36,9 @@ struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
|
|||
};
|
||||
|
||||
/// This is a test pass for verifying FuncOp's setType method.
|
||||
struct TestFuncSetType : public ModulePass<TestFuncSetType> {
|
||||
void runOnModule() override {
|
||||
auto module = getModule();
|
||||
struct TestFuncSetType : public OperationPass<TestFuncSetType, ModuleOp> {
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
SymbolTable symbolTable(module);
|
||||
|
||||
for (FuncOp func : module.getOps<FuncOp>()) {
|
||||
|
|
|
@ -12,9 +12,9 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct SideEffectsPass : public ModulePass<SideEffectsPass> {
|
||||
void runOnModule() override {
|
||||
auto module = getModule();
|
||||
struct SideEffectsPass : public OperationPass<SideEffectsPass, ModuleOp> {
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
|
||||
// Walk operations detecting side effects.
|
||||
SmallVector<MemoryEffects::EffectInstance, 8> effects;
|
||||
|
|
|
@ -15,7 +15,7 @@ using namespace mlir;
|
|||
namespace {
|
||||
/// This is a symbol test pass that tests the symbol uselist functionality
|
||||
/// provided by the symbol table along with erasing from the symbol table.
|
||||
struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
|
||||
struct SymbolUsesPass : public OperationPass<SymbolUsesPass, ModuleOp> {
|
||||
WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
|
||||
SmallVectorImpl<FuncOp> &deadFunctions) {
|
||||
// Test computing uses on a non symboltable op.
|
||||
|
@ -59,8 +59,8 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
|
|||
return WalkResult::advance();
|
||||
}
|
||||
|
||||
void runOnModule() override {
|
||||
auto module = getModule();
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
|
||||
// Walk nested symbols.
|
||||
SmallVector<FuncOp, 4> deadFunctions;
|
||||
|
@ -86,9 +86,10 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
|
|||
|
||||
/// This is a symbol test pass that tests the symbol use replacement
|
||||
/// functionality provided by the symbol table.
|
||||
struct SymbolReplacementPass : public ModulePass<SymbolReplacementPass> {
|
||||
void runOnModule() override {
|
||||
auto module = getModule();
|
||||
struct SymbolReplacementPass
|
||||
: public OperationPass<SymbolReplacementPass, ModuleOp> {
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
|
||||
// Walk nested functions and modules.
|
||||
module.getBodyRegion().walk([&](Operation *nestedOp) {
|
||||
|
|
|
@ -13,8 +13,8 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct TestModulePass : public ModulePass<TestModulePass> {
|
||||
void runOnModule() final {}
|
||||
struct TestModulePass : public OperationPass<TestModulePass, ModuleOp> {
|
||||
void runOnOperation() final {}
|
||||
};
|
||||
struct TestFunctionPass : public FunctionPass<TestFunctionPass> {
|
||||
void runOnFunction() final {}
|
||||
|
|
|
@ -18,11 +18,11 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
struct TestAllReduceLoweringPass
|
||||
: public ModulePass<TestAllReduceLoweringPass> {
|
||||
void runOnModule() override {
|
||||
: public OperationPass<TestAllReduceLoweringPass, ModuleOp> {
|
||||
void runOnOperation() override {
|
||||
OwningRewritePatternList patterns;
|
||||
populateGpuRewritePatterns(&getContext(), patterns);
|
||||
applyPatternsGreedily(getModule(), patterns);
|
||||
applyPatternsGreedily(getOperation(), patterns);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
|
|
@ -17,9 +17,9 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct TestCallGraphPass : public ModulePass<TestCallGraphPass> {
|
||||
void runOnModule() {
|
||||
llvm::errs() << "Testing : " << getModule().getAttr("test.name") << "\n";
|
||||
struct TestCallGraphPass : public OperationPass<TestCallGraphPass, ModuleOp> {
|
||||
void runOnOperation() override {
|
||||
llvm::errs() << "Testing : " << getOperation().getAttr("test.name") << "\n";
|
||||
getAnalysis<CallGraph>().print(llvm::errs());
|
||||
}
|
||||
};
|
||||
|
|
|
@ -17,7 +17,7 @@ namespace {
|
|||
/// It also takes all operations that are not function operations or
|
||||
/// terminators and clones them with opaque locations which store the initial
|
||||
/// locations.
|
||||
struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
|
||||
struct TestOpaqueLoc : public OperationPass<TestOpaqueLoc, ModuleOp> {
|
||||
|
||||
/// A simple structure which is used for testing as an underlying location in
|
||||
/// OpaqueLoc.
|
||||
|
@ -29,11 +29,11 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
|
|||
int id;
|
||||
};
|
||||
|
||||
void runOnModule() override {
|
||||
void runOnOperation() override {
|
||||
std::vector<std::unique_ptr<MyLocation>> myLocs;
|
||||
int last_it = 0;
|
||||
|
||||
getModule().walk([&](Operation *op) {
|
||||
getOperation().walk([&](Operation *op) {
|
||||
myLocs.push_back(std::make_unique<MyLocation>(last_it++));
|
||||
|
||||
Location loc = op->getLoc();
|
||||
|
@ -74,7 +74,7 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
|
|||
os.flush();
|
||||
});
|
||||
|
||||
getModule().walk([&](Operation *op) { op->emitOpError(); });
|
||||
getOperation().walk([&](Operation *op) { op->emitOpError(); });
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue