[mlir][python] Add python support for async dialect and passes.

since the `async` keyword is reserved in python, the dialect is called async_dialect.

Differential Revision: https://reviews.llvm.org/D101447
This commit is contained in:
Nicolas Vasilache 2021-04-27 19:57:56 +00:00
parent 262c679d32
commit e7db8408d0
11 changed files with 165 additions and 0 deletions

View File

@ -0,0 +1,28 @@
//===-- mlir-c/Dialect/Async.h - C API for Async dialect ---------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
#ifndef MLIR_C_DIALECT_ASYNC_H
#define MLIR_C_DIALECT_ASYNC_H
#include "mlir-c/Registration.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Async, async);
#ifdef __cplusplus
}
#endif
#include "mlir/Dialect/Async/Passes.capi.h.inc"
#endif // MLIR_C_DIALECT_ASYNC_H

View File

@ -2,6 +2,8 @@ add_subdirectory(IR)
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Async)
mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix Async)
mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Async)
add_public_tablegen_target(MLIRAsyncPassIncGen)
add_mlir_doc(Passes AsyncPasses ./ -gen-pass-doc)

View File

@ -0,0 +1,15 @@
//===-- AsyncOps.td - Entry point async_dialect bindings --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
#ifndef PYTHON_BINDINGS_ASYNC_OPS
#define PYTHON_BINDINGS_ASYNC_OPS
include "mlir/Bindings/Python/Attributes.td"
include "mlir/Dialect/Async/IR/AsyncOps.td"
#endif

View File

@ -0,0 +1,22 @@
//===- AsyncPasses.cpp - Pybind module for the Async passes -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Async.h"
#include <pybind11/pybind11.h>
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
PYBIND11_MODULE(_mlirAsyncPasses, m) {
m.doc() = "MLIR Async Dialect Passes";
// Register all Async passes on load.
mlirRegisterAsyncPasses();
}

View File

@ -31,6 +31,11 @@ endforeach()
# Generate dialect-specific bindings.
################################################################################
add_mlir_dialect_python_bindings(MLIRBindingsPythonAsyncOps
TD_FILE AsyncOps.td
DIALECT_NAME async_dialect)
add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonAsyncOps)
add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps
TD_FILE BuiltinOps.td
DIALECT_NAME builtin)
@ -120,6 +125,14 @@ endif()
add_subdirectory(Transforms)
add_subdirectory(Conversions)
add_mlir_python_extension(MLIRAsyncPassesBindingsPythonExtension _mlirAsyncPasses
INSTALL_DIR
python
SOURCES
AsyncPasses.cpp
)
add_dependencies(MLIRBindingsPythonExtension MLIRAsyncPassesBindingsPythonExtension)
add_mlir_python_extension(MLIRLinalgPassesBindingsPythonExtension _mlirLinalgPasses
INSTALL_DIR
python

View File

@ -0,0 +1,5 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from .._async_dialect_ops_gen import *

View File

@ -0,0 +1,6 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ...._cext_loader import _load_extension
_cextAsyncPasses = _load_extension("_mlirAsyncPasses")

View File

@ -0,0 +1,13 @@
//===- Async.cpp - C Interface for Async dialect --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir-c/Dialect/Async.h"
#include "mlir/CAPI/Registration.h"
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Async, async, mlir::async::AsyncDialect)

View File

@ -0,0 +1,26 @@
//===- AsyncPasses.cpp - C API for Async Dialect Passes -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/CAPI/Pass.h"
#include "mlir/Dialect/Async/Passes.h"
#include "mlir/Pass/Pass.h"
// Must include the declarations as they carry important visibility attributes.
#include "mlir/Dialect/Async/Passes.capi.h.inc"
using namespace mlir;
#ifdef __cplusplus
extern "C" {
#endif
#include "mlir/Dialect/Async/Passes.capi.cpp.inc"
#ifdef __cplusplus
}
#endif

View File

@ -1,5 +1,7 @@
# TODO: Make the check source feature optional as an argument on *_add_library.
set(LLVM_OPTIONAL_SOURCES
Async.cpp
AsyncPasses.cpp
Linalg.cpp
LinalgPasses.cpp
SCF.cpp
@ -8,6 +10,20 @@ set(LLVM_OPTIONAL_SOURCES
Tensor.cpp
)
add_mlir_public_c_api_library(MLIRCAPIAsync
Async.cpp
AsyncPasses.cpp
DEPENDS
MLIRAsyncPassIncGen
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIRAsync
MLIRAsyncTransforms
MLIRPass
)
add_mlir_public_c_api_library(MLIRCAPILinalg
Linalg.cpp
LinalgPasses.cpp

View File

@ -0,0 +1,19 @@
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
import mlir.dialects.async_dialect
import mlir.dialects.async_dialect.passes
from mlir.passmanager import *
def run(f):
print("\nTEST:", f.__name__)
f()
def testAsyncPass():
with Context() as context:
PassManager.parse('async-to-async-runtime')
print('SUCCESS')
# CHECK-LABEL: testAsyncPass
# CHECK: SUCCESS
run(testAsyncPass)