[mlir][python] Allow adding to existing pass manager
This adds a `PassManager.add` method which adds pipeline elements to the pass manager. This allows for progressively building up a pipeline from python without string manipulation. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D137344
This commit is contained in:
parent
2f211f865d
commit
dd1b1d4450
|
@ -100,6 +100,20 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
|
|||
"Parse a textual pass-pipeline and return a top-level PassManager "
|
||||
"that can be applied on a Module. Throw a ValueError if the pipeline "
|
||||
"can't be parsed")
|
||||
.def(
|
||||
"add",
|
||||
[](PyPassManager &passManager, const std::string &pipeline) {
|
||||
PyPrintAccumulator errorMsg;
|
||||
MlirLogicalResult status = mlirOpPassManagerAddPipeline(
|
||||
mlirPassManagerGetAsOpPassManager(passManager.get()),
|
||||
mlirStringRefCreate(pipeline.data(), pipeline.size()),
|
||||
errorMsg.getCallback(), errorMsg.getUserData());
|
||||
if (mlirLogicalResultIsFailure(status))
|
||||
throw SetPyError(PyExc_ValueError, std::string(errorMsg.join()));
|
||||
},
|
||||
py::arg("pipeline"),
|
||||
"Add textual pipeline elements to the pass manager. Throws a "
|
||||
"ValueError if the pipeline can't be parsed.")
|
||||
.def(
|
||||
"run",
|
||||
[](PyPassManager &passManager, PyModule &module) {
|
||||
|
|
|
@ -191,11 +191,17 @@ def transform(module, boilerplate):
|
|||
ops = module.operation.regions[0].blocks[0].operations
|
||||
mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate)
|
||||
|
||||
pm = PassManager.parse(
|
||||
"builtin.module(func.func(convert-linalg-to-loops, lower-affine, " +
|
||||
"convert-math-to-llvm, convert-scf-to-cf, arith-expand, memref-expand), "
|
||||
+ "convert-vector-to-llvm, convert-memref-to-llvm, convert-func-to-llvm," +
|
||||
"reconcile-unrealized-casts)")
|
||||
pm = PassManager('builtin.module')
|
||||
pm.add("func.func(convert-linalg-to-loops)")
|
||||
pm.add("func.func(lower-affine)")
|
||||
pm.add("func.func(convert-math-to-llvm)")
|
||||
pm.add("func.func(convert-scf-to-cf)")
|
||||
pm.add("func.func(arith-expand)")
|
||||
pm.add("func.func(memref-expand)")
|
||||
pm.add("convert-vector-to-llvm")
|
||||
pm.add("convert-memref-to-llvm")
|
||||
pm.add("convert-func-to-llvm")
|
||||
pm.add("reconcile-unrealized-casts")
|
||||
pm.run(mod)
|
||||
return mod
|
||||
|
||||
|
|
|
@ -75,6 +75,20 @@ def testParseFail():
|
|||
log("Exception not produced")
|
||||
run(testParseFail)
|
||||
|
||||
# Check that adding to a pass manager works
|
||||
# CHECK-LABEL: TEST: testAdd
|
||||
@run
|
||||
def testAdd():
|
||||
pm = PassManager("any", Context())
|
||||
# CHECK: pm: 'any()'
|
||||
log(f"pm: '{pm}'")
|
||||
# CHECK: pm: 'any(cse)'
|
||||
pm.add("cse")
|
||||
log(f"pm: '{pm}'")
|
||||
# CHECK: pm: 'any(cse,cse)'
|
||||
pm.add("cse")
|
||||
log(f"pm: '{pm}'")
|
||||
|
||||
|
||||
# Verify failure on incorrect level of nesting.
|
||||
# CHECK-LABEL: TEST: testInvalidNesting
|
||||
|
|
Loading…
Reference in New Issue