[mlir][python] Allow adding to existing pass manager

This adds a `PassManager.add` method which adds pipeline elements to the
pass manager. This allows for progressively building up a pipeline from
python without string manipulation.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D137344
This commit is contained in:
rkayaith 2022-10-20 00:27:09 -04:00
parent 2f211f865d
commit dd1b1d4450
3 changed files with 39 additions and 5 deletions

View File

@ -100,6 +100,20 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"Parse a textual pass-pipeline and return a top-level PassManager "
"that can be applied on a Module. Throw a ValueError if the pipeline "
"can't be parsed")
.def(
"add",
[](PyPassManager &passManager, const std::string &pipeline) {
PyPrintAccumulator errorMsg;
MlirLogicalResult status = mlirOpPassManagerAddPipeline(
mlirPassManagerGetAsOpPassManager(passManager.get()),
mlirStringRefCreate(pipeline.data(), pipeline.size()),
errorMsg.getCallback(), errorMsg.getUserData());
if (mlirLogicalResultIsFailure(status))
throw SetPyError(PyExc_ValueError, std::string(errorMsg.join()));
},
py::arg("pipeline"),
"Add textual pipeline elements to the pass manager. Throws a "
"ValueError if the pipeline can't be parsed.")
.def(
"run",
[](PyPassManager &passManager, PyModule &module) {

View File

@ -191,11 +191,17 @@ def transform(module, boilerplate):
ops = module.operation.regions[0].blocks[0].operations
mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate)
pm = PassManager.parse(
"builtin.module(func.func(convert-linalg-to-loops, lower-affine, " +
"convert-math-to-llvm, convert-scf-to-cf, arith-expand, memref-expand), "
+ "convert-vector-to-llvm, convert-memref-to-llvm, convert-func-to-llvm," +
"reconcile-unrealized-casts)")
pm = PassManager('builtin.module')
pm.add("func.func(convert-linalg-to-loops)")
pm.add("func.func(lower-affine)")
pm.add("func.func(convert-math-to-llvm)")
pm.add("func.func(convert-scf-to-cf)")
pm.add("func.func(arith-expand)")
pm.add("func.func(memref-expand)")
pm.add("convert-vector-to-llvm")
pm.add("convert-memref-to-llvm")
pm.add("convert-func-to-llvm")
pm.add("reconcile-unrealized-casts")
pm.run(mod)
return mod

View File

@ -75,6 +75,20 @@ def testParseFail():
log("Exception not produced")
run(testParseFail)
# Check that adding to a pass manager works
# CHECK-LABEL: TEST: testAdd
@run
def testAdd():
pm = PassManager("any", Context())
# CHECK: pm: 'any()'
log(f"pm: '{pm}'")
# CHECK: pm: 'any(cse)'
pm.add("cse")
log(f"pm: '{pm}'")
# CHECK: pm: 'any(cse,cse)'
pm.add("cse")
log(f"pm: '{pm}'")
# Verify failure on incorrect level of nesting.
# CHECK-LABEL: TEST: testInvalidNesting