[mlir][Pass][NFC] Replace usages of ModulePass with OperationPass<ModuleOp>

ModulePass doesn't provide any special utilities and thus doesn't give enough benefit to warrant a special pass class. This revision replaces all usages with the more general OperationPass.

Differential Revision: https://reviews.llvm.org/D77339
This commit is contained in:
River Riddle 2020-04-07 13:55:34 -07:00
parent 2481f26ac3
commit 722f909f7a
26 changed files with 124 additions and 133 deletions

View File

@ -105,7 +105,7 @@ We want to completely lower to LLVM, so we use a `FullConversion`. This ensures
that only legal operations will remain after the conversion.
```c++
mlir::ModuleOp module = getModule();
mlir::ModuleOp module = getOperation();
if (mlir::failed(mlir::applyFullConversion(module, target, patterns,
&typeConverter)))
signalPassFailure();

View File

@ -153,12 +153,13 @@ private:
//===----------------------------------------------------------------------===//
namespace {
struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
void runOnModule() final;
struct ToyToLLVMLoweringPass
: public OperationPass<ToyToLLVMLoweringPass, ModuleOp> {
void runOnOperation() final;
};
} // end anonymous namespace
void ToyToLLVMLoweringPass::runOnModule() {
void ToyToLLVMLoweringPass::runOnOperation() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering. For this lowering, we are only targeting
// the LLVM dialect.
@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() {
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
auto module = getModule();
auto module = getOperation();
if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
signalPassFailure();
}

View File

@ -153,12 +153,13 @@ private:
//===----------------------------------------------------------------------===//
namespace {
struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
void runOnModule() final;
struct ToyToLLVMLoweringPass
: public OperationPass<ToyToLLVMLoweringPass, ModuleOp> {
void runOnOperation() final;
};
} // end anonymous namespace
void ToyToLLVMLoweringPass::runOnModule() {
void ToyToLLVMLoweringPass::runOnOperation() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering. For this lowering, we are only targeting
// the LLVM dialect.
@ -191,7 +192,7 @@ void ToyToLLVMLoweringPass::runOnModule() {
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
auto module = getModule();
auto module = getOperation();
if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
signalPassFailure();
}

View File

@ -341,24 +341,9 @@ template <typename T> struct FunctionPass : public OperationPass<T, FuncOp> {
runOnFunction();
}
/// Return the current module being transformed.
/// Return the current function being transformed.
FuncOp getFunction() { return this->getOperation(); }
};
/// A model for providing module pass specific utilities.
///
/// Derived module passes are expected to provide the following:
/// - A 'void runOnModule()' method.
template <typename T> struct ModulePass : public OperationPass<T, ModuleOp> {
/// The polymorphic API that runs the pass over the currently held module.
virtual void runOnModule() = 0;
/// The polymorphic API that runs the pass over the currently held operation.
void runOnOperation() final { runOnModule(); }
/// Return the current module being transformed.
ModuleOp getModule() { return this->getOperation(); }
};
} // end namespace mlir
#endif // MLIR_PASS_PASS_H

View File

@ -163,16 +163,17 @@ void mlir::populateAVX512ToLLVMConversionPatterns(
}
namespace {
struct ConvertAVX512ToLLVMPass : public ModulePass<ConvertAVX512ToLLVMPass> {
struct ConvertAVX512ToLLVMPass
: public OperationPass<ConvertAVX512ToLLVMPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertAVX512ToLLVM
#include "mlir/Conversion/Passes.h.inc"
void runOnModule() override;
void runOnOperation() override;
};
} // namespace
void ConvertAVX512ToLLVMPass::runOnModule() {
void ConvertAVX512ToLLVMPass::runOnOperation() {
// Convert to the LLVM IR dialect.
OwningRewritePatternList patterns;
LLVMTypeConverter converter(&getContext());
@ -186,8 +187,8 @@ void ConvertAVX512ToLLVMPass::runOnModule() {
target.addIllegalDialect<avx512::AVX512Dialect>();
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
if (failed(
applyPartialConversion(getModule(), target, patterns, &converter))) {
if (failed(applyPartialConversion(getOperation(), target, patterns,
&converter))) {
signalPassFailure();
}
}

View File

@ -61,7 +61,7 @@ namespace {
///
/// Intermediate data structures are allocated on the stack.
class GpuLaunchFuncToCudaCallsPass
: public ModulePass<GpuLaunchFuncToCudaCallsPass> {
: public OperationPass<GpuLaunchFuncToCudaCallsPass, ModuleOp> {
private:
/// Include the generated pass utilities.
#define GEN_PASS_ConvertGpuLaunchFuncToCudaCalls
@ -126,20 +126,19 @@ private:
public:
// Run the dialect converter on the module.
void runOnModule() override {
void runOnOperation() override {
// Cache the LLVMDialect for the current module.
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
// Cache the used LLVM types.
initializeCachedTypes();
getModule().walk([this](mlir::gpu::LaunchFuncOp op) {
translateGpuLaunchCalls(op);
});
getOperation().walk(
[this](mlir::gpu::LaunchFuncOp op) { translateGpuLaunchCalls(op); });
// GPU kernel modules are no longer necessary since we have a global
// constant with the CUBIN data.
for (auto m :
llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
m.erase();
}
@ -160,7 +159,7 @@ private:
// The types in comments give the actual types expected/returned but the API
// uses void pointers. This is fine as they have the same linkage in C.
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
ModuleOp module = getModule();
ModuleOp module = getOperation();
OpBuilder builder(module.getBody()->getTerminator());
if (!module.lookupSymbol(cuModuleLoadName)) {
builder.create<LLVM::LLVMFuncOp>(
@ -391,7 +390,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
builder.getI32IntegerAttr(0));
// Create an LLVM global with CUBIN extracted from the kernel annotation and
// obtain a pointer to the first byte in it.
auto kernelModule = getModule().lookupSymbol<gpu::GPUModuleOp>(
auto kernelModule = getOperation().lookupSymbol<gpu::GPUModuleOp>(
launchOp.getKernelModuleName());
assert(kernelModule && "expected a kernel module");
@ -412,7 +411,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
// in the called helper function.
auto cuModule = allocatePointer(builder, loc);
auto cuModuleLoad =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleLoadName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuModuleLoad),
ArrayRef<Value>{cuModule, data});
@ -423,20 +422,20 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder);
auto cuFunction = allocatePointer(builder, loc);
auto cuModuleGetFunction =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuModuleGetFunction),
ArrayRef<Value>{cuFunction, cuOwningModuleRef, kernelName});
// Grab the global stream needed for execution.
auto cuGetStreamHelper =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuGetStreamHelperName);
auto cuStream = builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getPointerType()},
builder.getSymbolRefAttr(cuGetStreamHelper), ArrayRef<Value>{});
// Invoke the function with required arguments.
auto cuLaunchKernel =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuLaunchKernelName);
auto cuFunctionRef =
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuFunction);
auto paramsArray = setupParamsArray(launchOp, builder);
@ -458,7 +457,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
nullpointer /* extra */});
// Sync on the stream to make it synchronous.
auto cuStreamSync =
getModule().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuStreamSynchronizeName);
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getCUResultType()},
builder.getSymbolRefAttr(cuStreamSync),
ArrayRef<Value>(cuStream.getResult(0)));

View File

@ -33,18 +33,18 @@ namespace {
/// replace it).
///
/// 2) Lower the body of the spirv::ModuleOp.
struct GPUToSPIRVPass : public ModulePass<GPUToSPIRVPass> {
struct GPUToSPIRVPass : public OperationPass<GPUToSPIRVPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertGpuToSPIRV
#include "mlir/Conversion/Passes.h.inc"
void runOnModule() override;
void runOnOperation() override;
};
} // namespace
void GPUToSPIRVPass::runOnModule() {
void GPUToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getModule();
ModuleOp module = getOperation();
SmallVector<Operation *, 1> kernelModules;
OpBuilder builder(context);

View File

@ -38,13 +38,13 @@ namespace {
/// function and attaching binary data and entry point name as an attributes to
/// created vulkan launch call op.
class ConvertGpuLaunchFuncToVulkanLaunchFunc
: public ModulePass<ConvertGpuLaunchFuncToVulkanLaunchFunc> {
: public OperationPass<ConvertGpuLaunchFuncToVulkanLaunchFunc, ModuleOp> {
public:
/// Include the generated pass utilities.
#define GEN_PASS_ConvertGpuLaunchFuncToVulkanLaunchFunc
#include "mlir/Conversion/Passes.h.inc"
void runOnModule() override;
void runOnOperation() override;
private:
/// Creates a SPIR-V binary shader from the given `module` using
@ -68,14 +68,13 @@ private:
/// operand is unsupported by Vulkan runtime.
LogicalResult declareVulkanLaunchFunc(Location loc,
gpu::LaunchFuncOp launchOp);
};
} // anonymous namespace
void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnOperation() {
bool done = false;
getModule().walk([this, &done](gpu::LaunchFuncOp op) {
getOperation().walk([this, &done](gpu::LaunchFuncOp op) {
if (done) {
op.emitError("should only contain one 'gpu::LaunchFuncOp' op");
return signalPassFailure();
@ -86,17 +85,17 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnModule() {
// Erase `gpu::GPUModuleOp` and `spirv::Module` operations.
for (auto gpuModule :
llvm::make_early_inc_range(getModule().getOps<gpu::GPUModuleOp>()))
llvm::make_early_inc_range(getOperation().getOps<gpu::GPUModuleOp>()))
gpuModule.erase();
for (auto spirvModule :
llvm::make_early_inc_range(getModule().getOps<spirv::ModuleOp>()))
llvm::make_early_inc_range(getOperation().getOps<spirv::ModuleOp>()))
spirvModule.erase();
}
LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
Location loc, gpu::LaunchFuncOp launchOp) {
OpBuilder builder(getModule().getBody()->getTerminator());
OpBuilder builder(getOperation().getBody()->getTerminator());
// TODO: Workgroup size is written into the kernel. So to properly model
// vulkan launch, we cannot have the local workgroup size configuration here.
SmallVector<Type, 8> vulkanLaunchTypes{launchOp.getOperandTypes()};
@ -138,7 +137,7 @@ LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::createBinaryShader(
void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
gpu::LaunchFuncOp launchOp) {
ModuleOp module = getModule();
ModuleOp module = getOperation();
OpBuilder builder(launchOp);
Location loc = launchOp.getLoc();

View File

@ -58,7 +58,7 @@ namespace {
/// * deinitVulkan -- deinitializes vulkan runtime
///
class VulkanLaunchFuncToVulkanCallsPass
: public ModulePass<VulkanLaunchFuncToVulkanCallsPass> {
: public OperationPass<VulkanLaunchFuncToVulkanCallsPass, ModuleOp> {
private:
/// Include the generated pass utilities.
#define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls
@ -150,7 +150,7 @@ private:
LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
public:
void runOnModule() override;
void runOnOperation() override;
private:
LLVM::LLVMDialect *llvmDialect;
@ -169,18 +169,18 @@ private:
} // anonymous namespace
void VulkanLaunchFuncToVulkanCallsPass::runOnModule() {
void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() {
initializeCachedTypes();
// Collect SPIR-V attributes such as `spirv_blob` and
// `spirv_entry_point_name`.
getModule().walk([this](LLVM::CallOp op) {
getOperation().walk([this](LLVM::CallOp op) {
if (isVulkanLaunchCallOp(op))
collectSPIRVAttributes(op);
});
// Convert vulkan launch call op into a sequence of Vulkan runtime calls.
getModule().walk([this](LLVM::CallOp op) {
getOperation().walk([this](LLVM::CallOp op) {
if (isCInterfaceVulkanLaunchCallOp(op))
translateVulkanLaunchCall(op);
});
@ -278,7 +278,7 @@ VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
}
void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
ModuleOp module = getModule();
ModuleOp module = getOperation();
OpBuilder builder(module.getBody()->getTerminator());
if (!module.lookupSymbol(kSetEntryPoint)) {

View File

@ -561,17 +561,18 @@ void mlir::populateLinalgToLLVMConversionPatterns(
}
namespace {
struct ConvertLinalgToLLVMPass : public ModulePass<ConvertLinalgToLLVMPass> {
struct ConvertLinalgToLLVMPass
: public OperationPass<ConvertLinalgToLLVMPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertLinalgToLLVM
#include "mlir/Conversion/Passes.h.inc"
void runOnModule() override;
void runOnOperation() override;
};
} // namespace
void ConvertLinalgToLLVMPass::runOnModule() {
auto module = getModule();
void ConvertLinalgToLLVMPass::runOnOperation() {
auto module = getOperation();
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;

View File

@ -16,18 +16,18 @@ using namespace mlir;
namespace {
/// A pass converting MLIR Linalg ops into SPIR-V ops.
class LinalgToSPIRVPass : public ModulePass<LinalgToSPIRVPass> {
class LinalgToSPIRVPass : public OperationPass<LinalgToSPIRVPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertLinalgToSPIRV
#include "mlir/Conversion/Passes.h.inc"
void runOnModule() override;
void runOnOperation() override;
};
} // namespace
void LinalgToSPIRVPass::runOnModule() {
void LinalgToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getModule();
ModuleOp module = getOperation();
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =

View File

@ -2847,7 +2847,7 @@ LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands,
namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
struct LLVMLoweringPass : public OperationPass<LLVMLoweringPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertStandardToLLVM
#include "mlir/Conversion/Passes.h.inc"
@ -2863,16 +2863,16 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
LLVMLoweringPass(const LLVMLoweringPass &pass) {}
/// Run the dialect converter on the module.
void runOnModule() override {
void runOnOperation() override {
if (useBarePtrCallConv && emitCWrappers) {
getModule().emitError()
getOperation().emitError()
<< "incompatible conversion options: bare-pointer calling convention "
"and C wrapper emission";
signalPassFailure();
return;
}
ModuleOp m = getModule();
ModuleOp m = getOperation();
LLVMTypeConverterCustomization customs;
customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter

View File

@ -22,18 +22,18 @@ using namespace mlir;
namespace {
/// A pass converting MLIR Standard operations into the SPIR-V dialect.
class ConvertStandardToSPIRVPass
: public ModulePass<ConvertStandardToSPIRVPass> {
: public OperationPass<ConvertStandardToSPIRVPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertStandardToSPIRV
#include "mlir/Conversion/Passes.h.inc"
void runOnModule() override;
void runOnOperation() override;
};
} // namespace
void ConvertStandardToSPIRVPass::runOnModule() {
void ConvertStandardToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getModule();
ModuleOp module = getOperation();
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =

View File

@ -1118,23 +1118,24 @@ void mlir::populateVectorToLLVMMatrixConversionPatterns(
}
namespace {
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
struct LowerVectorToLLVMPass
: public OperationPass<LowerVectorToLLVMPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_ConvertVectorToLLVM
#include "mlir/Conversion/Passes.h.inc"
void runOnModule() override;
void runOnOperation() override;
};
} // namespace
void LowerVectorToLLVMPass::runOnModule() {
void LowerVectorToLLVMPass::runOnOperation() {
// Perform progressive lowering of operations on slices and
// all contraction operations. Also applies folding and DCE.
{
OwningRewritePatternList patterns;
populateVectorSlicesLoweringPatterns(patterns, &getContext());
populateVectorContractLoweringPatterns(patterns, &getContext());
applyPatternsGreedily(getModule(), patterns);
applyPatternsGreedily(getOperation(), patterns);
}
// Convert to the LLVM IR dialect.
@ -1148,8 +1149,8 @@ void LowerVectorToLLVMPass::runOnModule() {
LLVMConversionTarget target(getContext());
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
if (failed(
applyPartialConversion(getModule(), target, patterns, &converter))) {
if (failed(applyPartialConversion(getOperation(), target, patterns,
&converter))) {
signalPassFailure();
}
}

View File

@ -214,16 +214,17 @@ namespace {
/// The gpu.modules are intended to be compiled to a cubin blob independently in
/// a separate pass. The external functions can then be annotated with the
/// symbol of the cubin accessor function.
class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
class GpuKernelOutliningPass
: public OperationPass<GpuKernelOutliningPass, ModuleOp> {
public:
/// Include the generated pass utilities.
#define GEN_PASS_GpuKernelOutlining
#include "mlir/Dialect/GPU/Passes.h.inc"
void runOnModule() override {
SymbolTable symbolTable(getModule());
void runOnOperation() override {
SymbolTable symbolTable(getOperation());
bool modified = false;
for (auto func : getModule().getOps<FuncOp>()) {
for (auto func : getOperation().getOps<FuncOp>()) {
// Insert just after the function.
Block::iterator insertPt(func.getOperation()->getNextNode());
auto funcWalkResult = func.walk([&](gpu::LaunchOp op) {
@ -255,8 +256,8 @@ public:
// If any new module was inserted in this module, annotate this module as
// a container module.
if (modified)
getModule().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
UnitAttr::get(&getContext()));
getOperation().setAttr(gpu::GPUDialect::getContainerModuleAttrName(),
UnitAttr::get(&getContext()));
}
private:
@ -267,7 +268,7 @@ private:
// a SymbolTable by the caller. SymbolTable needs to be refactored to
// prevent manual building of Ops with symbols in code using SymbolTables
// and then this needs to use the OpBuilder.
auto context = getModule().getContext();
auto context = getOperation().getContext();
Builder builder(context);
OperationState state(kernelFunc.getLoc(),
gpu::GPUModuleOp::getOperationName());

View File

@ -80,14 +80,14 @@ static void populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns,
namespace {
class DecorateSPIRVCompositeTypeLayoutPass
: public ModulePass<DecorateSPIRVCompositeTypeLayoutPass> {
: public OperationPass<DecorateSPIRVCompositeTypeLayoutPass, ModuleOp> {
private:
void runOnModule() override;
void runOnOperation() override;
};
} // namespace
void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() {
auto module = getModule();
void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
auto module = getOperation();
OwningRewritePatternList patterns;
populateSPIRVLayoutInfoPatterns(patterns, module.getContext());
ConversionTarget target(*(module.getContext()));

View File

@ -18,7 +18,7 @@
using namespace mlir;
namespace {
struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
struct PrintOpStatsPass : public OperationPass<PrintOpStatsPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_PrintOpStats
#include "mlir/Transforms/Passes.h.inc"
@ -26,7 +26,7 @@ struct PrintOpStatsPass : public ModulePass<PrintOpStatsPass> {
explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {}
// Prints the resulting operation statistics after iterating over the module.
void runOnModule() override;
void runOnOperation() override;
// Print summary of op stats.
void printSummary();
@ -37,11 +37,11 @@ private:
};
} // namespace
void PrintOpStatsPass::runOnModule() {
void PrintOpStatsPass::runOnOperation() {
opCount.clear();
// Compute the operation statistics for each function in the module.
for (auto &op : getModule())
for (auto &op : getOperation())
op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
printSummary();
}

View File

@ -100,7 +100,7 @@ namespace {
// PrintOpPass is simple pass to write graph per function.
// Note: this is a module pass only to avoid interleaving on the same ostream
// due to multi-threading over functions.
struct PrintOpPass : public ModulePass<PrintOpPass> {
struct PrintOpPass : public OperationPass<PrintOpPass, ModuleOp> {
/// Include the generated pass utilities.
#define GEN_PASS_PrintOpGraph
#include "mlir/Transforms/Passes.h.inc"
@ -140,7 +140,7 @@ struct PrintOpPass : public ModulePass<PrintOpPass> {
}
}
void runOnModule() override { processModule(getModule()); }
void runOnOperation() override { processModule(getOperation()); }
private:
raw_ostream &os;

View File

@ -398,13 +398,13 @@ struct TestTypeConverter : public TypeConverter {
};
struct TestLegalizePatternDriver
: public ModulePass<TestLegalizePatternDriver> {
: public OperationPass<TestLegalizePatternDriver, ModuleOp> {
/// The mode of conversion to use with the driver.
enum class ConversionMode { Analysis, Full, Partial };
TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
void runOnModule() override {
void runOnOperation() override {
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
@ -450,7 +450,8 @@ struct TestLegalizePatternDriver
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
(void)applyPartialConversion(getModule(), target, patterns, &converter);
(void)applyPartialConversion(getOperation(), target, patterns,
&converter);
return;
}
@ -461,7 +462,7 @@ struct TestLegalizePatternDriver
return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
});
(void)applyFullConversion(getModule(), target, patterns, &converter);
(void)applyFullConversion(getOperation(), target, patterns, &converter);
return;
}
@ -470,7 +471,7 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
if (failed(applyAnalysisConversion(getModule(), target, patterns,
if (failed(applyAnalysisConversion(getOperation(), target, patterns,
legalizedOps, &converter)))
return signalPassFailure();

View File

@ -13,9 +13,9 @@ using namespace mlir;
namespace {
/// This is a test pass for verifying FuncOp's eraseArgument method.
struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
void runOnModule() override {
auto module = getModule();
struct TestFuncEraseArg : public OperationPass<TestFuncEraseArg, ModuleOp> {
void runOnOperation() override {
auto module = getOperation();
for (FuncOp func : module.getOps<FuncOp>()) {
SmallVector<unsigned, 4> indicesToErase;
@ -36,9 +36,9 @@ struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
};
/// This is a test pass for verifying FuncOp's setType method.
struct TestFuncSetType : public ModulePass<TestFuncSetType> {
void runOnModule() override {
auto module = getModule();
struct TestFuncSetType : public OperationPass<TestFuncSetType, ModuleOp> {
void runOnOperation() override {
auto module = getOperation();
SymbolTable symbolTable(module);
for (FuncOp func : module.getOps<FuncOp>()) {

View File

@ -12,9 +12,9 @@
using namespace mlir;
namespace {
struct SideEffectsPass : public ModulePass<SideEffectsPass> {
void runOnModule() override {
auto module = getModule();
struct SideEffectsPass : public OperationPass<SideEffectsPass, ModuleOp> {
void runOnOperation() override {
auto module = getOperation();
// Walk operations detecting side effects.
SmallVector<MemoryEffects::EffectInstance, 8> effects;

View File

@ -15,7 +15,7 @@ using namespace mlir;
namespace {
/// This is a symbol test pass that tests the symbol uselist functionality
/// provided by the symbol table along with erasing from the symbol table.
struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
struct SymbolUsesPass : public OperationPass<SymbolUsesPass, ModuleOp> {
WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
SmallVectorImpl<FuncOp> &deadFunctions) {
// Test computing uses on a non symboltable op.
@ -59,8 +59,8 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
return WalkResult::advance();
}
void runOnModule() override {
auto module = getModule();
void runOnOperation() override {
auto module = getOperation();
// Walk nested symbols.
SmallVector<FuncOp, 4> deadFunctions;
@ -86,9 +86,10 @@ struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
/// This is a symbol test pass that tests the symbol use replacement
/// functionality provided by the symbol table.
struct SymbolReplacementPass : public ModulePass<SymbolReplacementPass> {
void runOnModule() override {
auto module = getModule();
struct SymbolReplacementPass
: public OperationPass<SymbolReplacementPass, ModuleOp> {
void runOnOperation() override {
auto module = getOperation();
// Walk nested functions and modules.
module.getBodyRegion().walk([&](Operation *nestedOp) {

View File

@ -13,8 +13,8 @@
using namespace mlir;
namespace {
struct TestModulePass : public ModulePass<TestModulePass> {
void runOnModule() final {}
struct TestModulePass : public OperationPass<TestModulePass, ModuleOp> {
void runOnOperation() final {}
};
struct TestFunctionPass : public FunctionPass<TestFunctionPass> {
void runOnFunction() final {}

View File

@ -18,11 +18,11 @@ using namespace mlir;
namespace {
struct TestAllReduceLoweringPass
: public ModulePass<TestAllReduceLoweringPass> {
void runOnModule() override {
: public OperationPass<TestAllReduceLoweringPass, ModuleOp> {
void runOnOperation() override {
OwningRewritePatternList patterns;
populateGpuRewritePatterns(&getContext(), patterns);
applyPatternsGreedily(getModule(), patterns);
applyPatternsGreedily(getOperation(), patterns);
}
};
} // namespace

View File

@ -17,9 +17,9 @@
using namespace mlir;
namespace {
struct TestCallGraphPass : public ModulePass<TestCallGraphPass> {
void runOnModule() {
llvm::errs() << "Testing : " << getModule().getAttr("test.name") << "\n";
struct TestCallGraphPass : public OperationPass<TestCallGraphPass, ModuleOp> {
void runOnOperation() override {
llvm::errs() << "Testing : " << getOperation().getAttr("test.name") << "\n";
getAnalysis<CallGraph>().print(llvm::errs());
}
};

View File

@ -17,7 +17,7 @@ namespace {
/// It also takes all operations that are not function operations or
/// terminators and clones them with opaque locations which store the initial
/// locations.
struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
struct TestOpaqueLoc : public OperationPass<TestOpaqueLoc, ModuleOp> {
/// A simple structure which is used for testing as an underlying location in
/// OpaqueLoc.
@ -29,11 +29,11 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
int id;
};
void runOnModule() override {
void runOnOperation() override {
std::vector<std::unique_ptr<MyLocation>> myLocs;
int last_it = 0;
getModule().walk([&](Operation *op) {
getOperation().walk([&](Operation *op) {
myLocs.push_back(std::make_unique<MyLocation>(last_it++));
Location loc = op->getLoc();
@ -74,7 +74,7 @@ struct TestOpaqueLoc : public ModulePass<TestOpaqueLoc> {
os.flush();
});
getModule().walk([&](Operation *op) { op->emitOpError(); });
getOperation().walk([&](Operation *op) { op->emitOpError(); });
}
};