[mlir] add OperationType to the Transform dialect

Add a new OperationType handle type to the Transform dialect. This
transform type is parameterized by the name of the payload operation it
can point to. It is intended as a constraint on transformations that are
only applicable to a specific kind of payload operations. If a
transformation is applicable to a small set of operation classes, it can
be wrapped into a transform op by using a disjunctive constraint, such
as `Type<Or<[Transform_ConcreteOperation<"foo">.predicate,
Transform_ConcreteOperation<"bar">.predicate]>>` for its operand without
modifying this type. Broader sets of accepted operations should be
modeled as specific types.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D135586
This commit is contained in:
Alex Zinenko 2022-10-10 14:33:59 +00:00
parent 6bb997c032
commit 3e1f6d02f7
21 changed files with 451 additions and 11 deletions

View File

@ -0,0 +1,46 @@
//===-- mlir-c/Dialect/Transform.h - C API for Transform Dialect --*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_C_DIALECT_TRANSFORM_H
#define MLIR_C_DIALECT_TRANSFORM_H
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
//===---------------------------------------------------------------------===//
// AnyOpType
//===---------------------------------------------------------------------===//
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type);
MLIR_CAPI_EXPORTED MlirType
mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
MLIR_CAPI_EXPORTED MlirStringRef
mlirTransformOperationTypeGetOperationName(MlirType type);
#ifdef __cplusplus
}
#endif
#endif // MLIR_C_DIALECT_TRANSFORM_H

View File

@ -359,6 +359,8 @@ def Transform_Dialect : Dialect {
/// mnemonic.
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);
void initializeTypes();
template <typename, typename...>
friend class TransformDialectExtension;

View File

@ -104,6 +104,13 @@ def TransformTypeInterface : TypeInterface<"TransformTypeInterface"> {
"::mlir::ArrayRef<::mlir::Operation *>":$payload)
>
];
let extraSharedClassDeclaration = [{
DiagnosedSilenceableFailure emitSilenceableError(Location loc) const {
Diagnostic diag(loc, DiagnosticSeverity::Error);
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
}];
}
def FunctionalStyleTransformOpTrait

View File

@ -23,4 +23,24 @@ def Transform_AnyOpType : TypeDef<Transform_Dialect, "AnyOp",
let assemblyFormat = "";
}
def Transform_OperationType : TypeDef<Transform_Dialect, "Operation",
[DeclareTypeInterfaceMethods<TransformTypeInterface>]> {
let description = [{
Transform IR handle that can be associated with a list of Payload IR
operations with the specified operation name.
}];
let mnemonic = "op";
let parameters = (ins
StringRefParameter<"Name of the allowed payload operation">:$operation_name
);
let assemblyFormat = "`<` $operation_name `>`";
}
class Transform_ConcreteOpType<string opname>
: Type<And<[Transform_OperationType.predicate,
CPred<"$_self.cast<::mlir::transform::OperationType>()"
".getOperationName() == \"" # opname # "\"">]>,
"Transform IR handle to " # opname # " operations",
"::mlir::transform::OperationType">;
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES

View File

@ -0,0 +1,64 @@
//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::adaptors;
void populateDialectTransformSubmodule(const pybind11::module &m) {
//===-------------------------------------------------------------------===//
// AnyOpType
//===-------------------------------------------------------------------===//
auto anyOpType =
mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
anyOpType.def_classmethod(
"get",
[](py::object cls, MlirContext ctx) {
return cls(mlirTransformAnyOpTypeGet(ctx));
},
"Get an instance of AnyOpType in the given context.", py::arg("cls"),
py::arg("context") = py::none());
//===-------------------------------------------------------------------===//
// OperationType
//===-------------------------------------------------------------------===//
auto operationType =
mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType);
operationType.def_classmethod(
"get",
[](py::object cls, const std::string &operationName, MlirContext ctx) {
MlirStringRef cOperationName =
mlirStringRefCreate(operationName.data(), operationName.size());
return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
},
"Get an instance of OperationType for the given kind in the given "
"context",
py::arg("cls"), py::arg("operation_name"),
py::arg("context") = py::none());
operationType.def_property_readonly(
"operation_name",
[](MlirType type) {
MlirStringRef operationName =
mlirTransformOperationTypeGetOperationName(type);
return py::str(operationName.data, operationName.length);
},
"Get the name of the payload operation accepted by the handle.");
}
PYBIND11_MODULE(_mlirDialectsTransform, m) {
m.doc() = "MLIR Transform dialect.";
populateDialectTransformSubmodule(m);
}

View File

@ -116,6 +116,15 @@ add_mlir_upstream_c_api_library(MLIRCAPITensor
MLIRTensorDialect
)
add_mlir_upstream_c_api_library(MLIRCAPITransformDialect
Transform.cpp
PARTIAL_SOURCES_INTENDED
LINK_LIBS PUBLIC
MLIRCAPIIR
MLIRTransformDialect
)
add_mlir_upstream_c_api_library(MLIRCAPIQuant
Quant.cpp

View File

@ -0,0 +1,48 @@
//===- Transform.cpp - C Interface for Transform dialect ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
using namespace mlir;
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform,
transform::TransformDialect)
//===---------------------------------------------------------------------===//
// AnyOpType
//===---------------------------------------------------------------------===//
bool mlirTypeIsATransformAnyOpType(MlirType type) {
return unwrap(type).isa<transform::AnyOpType>();
}
MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
return wrap(transform::AnyOpType::get(unwrap(ctx)));
}
//===---------------------------------------------------------------------===//
// OperationType
//===---------------------------------------------------------------------===//
bool mlirTypeIsATransformOperationType(MlirType type) {
return unwrap(type).isa<transform::OperationType>();
}
MlirType mlirTransformOperationTypeGet(MlirContext ctx,
MlirStringRef operationName) {
return wrap(
transform::OperationType::get(unwrap(ctx), unwrap(operationName)));
}
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
return wrap(unwrap(type).cast<transform::OperationType>().getOperationName());
}

View File

@ -60,10 +60,7 @@ void transform::TransformDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
addTypesChecked<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
>();
initializeTypes();
pdl::OperationType::attachInterface<
PDLOperationTypeTransformTypeInterfaceImpl>(*getContext());

View File

@ -30,8 +30,31 @@ generatedTypePrinter(Type def, AsmPrinter &printer);
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
void transform::TransformDialect::initializeTypes() {
addTypesChecked<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
>();
}
DiagnosedSilenceableFailure
transform::AnyOpType::checkPayload(Location loc,
ArrayRef<Operation *> payload) const {
return DiagnosedSilenceableFailure::success();
}
DiagnosedSilenceableFailure
transform::OperationType::checkPayload(Location loc,
ArrayRef<Operation *> payload) const {
OperationName opName(getOperationName(), loc.getContext());
for (Operation *op : payload) {
if (opName != op->getName()) {
DiagnosedSilenceableFailure diag =
emitSilenceableError(loc) << "incompatible payload operation name";
diag.attachNote(op->getLoc()) << "payload operation";
return diag;
}
}
return DiagnosedSilenceableFailure::success();
}

View File

@ -121,6 +121,7 @@ declare_mlir_dialect_python_bindings(
SOURCES
dialects/_transform_ops_ext.py
dialects/transform/__init__.py
_mlir_libs/_mlir/dialects/transform/__init__.pyi
DIALECT_NAME transform)
declare_mlir_dialect_extension_python_bindings(
@ -353,6 +354,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
MLIRCAPISparseTensor
)
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
MODULE_NAME _mlirDialectsTransform
ADD_TO_PARENT MLIRPythonSources.Dialects.transform
ROOT_DIR "${PYTHON_SOURCE_DIR}"
SOURCES
DialectTransform.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPITransformDialect
)
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
MODULE_NAME _mlirAsyncPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect

View File

@ -0,0 +1,26 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from typing import Optional
from mlir.ir import Type, Context
class AnyOpType(Type):
@staticmethod
def isinstance(type: Type) -> bool: ...
@staticmethod
def get(context: Optional[Context] = None) -> AnyOpType: ...
class OperationType(Type):
@staticmethod
def isinstance(type: Type) -> bool: ...
@staticmethod
def get(operation_name: str, context: Optional[Context] = None) -> OperationType: ...
@property
def operation_name(self) -> str: ...

View File

@ -18,6 +18,16 @@ def _get_symbol_ref_attr(value: Union[Attribute, str]):
return FlatSymbolRefAttr.get(value)
class CastOp:
def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
super().__init__(
result_type,
_get_op_result_or_value(target),
loc=loc,
ip=ip)
class GetClosestIsolatedParentOp:
def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):

View File

@ -18,3 +18,4 @@ class FailurePropagationMode(Enum):
return 2
from .._transform_ops_gen import *
from ..._mlir_libs._mlirDialectsTransform import *

View File

@ -54,6 +54,14 @@ _add_capi_test_executable(mlir-capi-pass-test
MLIRCAPITransforms
)
_add_capi_test_executable(mlir-capi-pdl-test
pdl.c
LINK_LIBS PRIVATE
MLIRCAPIIR
MLIRCAPIRegisterEverything
MLIRCAPIPDL
)
_add_capi_test_executable(mlir-capi-sparse-tensor-test
sparse_tensor.c
LINK_LIBS PRIVATE
@ -70,10 +78,10 @@ _add_capi_test_executable(mlir-capi-quant-test
MLIRCAPIQuant
)
_add_capi_test_executable(mlir-capi-pdl-test
pdl.c
_add_capi_test_executable(mlir-capi-transform-test
transform.c
LINK_LIBS PRIVATE
MLIRCAPIIR
MLIRCAPIRegisterEverything
MLIRCAPIPDL
MLIRCAPITransformDialect
)

View File

@ -0,0 +1,88 @@
//===- transform.c - Test of Transform dialect C API ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// RUN: mlir-capi-transform-test 2>&1 | FileCheck %s
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
// CHECK-LABEL: testAnyOpType
void testAnyOpType(MlirContext ctx) {
fprintf(stderr, "testAnyOpType\n");
MlirType parsedType = mlirTypeParseGet(
ctx, mlirStringRefCreateFromCString("!transform.any_op"));
MlirType constructedType = mlirTransformAnyOpTypeGet(ctx);
assert(!mlirTypeIsNull(parsedType) && "couldn't parse AnyOpType");
assert(!mlirTypeIsNull(constructedType) && "couldn't construct AnyOpType");
// CHECK: equal: 1
fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType));
// CHECK: parsedType isa AnyOpType: 1
fprintf(stderr, "parsedType isa AnyOpType: %d\n",
mlirTypeIsATransformAnyOpType(parsedType));
// CHECK: parsedType isa OperationType: 0
fprintf(stderr, "parsedType isa OperationType: %d\n",
mlirTypeIsATransformOperationType(parsedType));
// CHECK: !transform.any_op
mlirTypeDump(constructedType);
fprintf(stderr, "\n\n");
}
// CHECK-LABEL: testOperationType
void testOperationType(MlirContext ctx) {
fprintf(stderr, "testOperationType\n");
MlirType parsedType = mlirTypeParseGet(
ctx, mlirStringRefCreateFromCString("!transform.op<\"foo.bar\">"));
MlirType constructedType = mlirTransformOperationTypeGet(
ctx, mlirStringRefCreateFromCString("foo.bar"));
assert(!mlirTypeIsNull(parsedType) && "couldn't parse AnyOpType");
assert(!mlirTypeIsNull(constructedType) && "couldn't construct AnyOpType");
// CHECK: equal: 1
fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType));
// CHECK: parsedType isa AnyOpType: 0
fprintf(stderr, "parsedType isa AnyOpType: %d\n",
mlirTypeIsATransformAnyOpType(parsedType));
// CHECK: parsedType isa OperationType: 1
fprintf(stderr, "parsedType isa OperationType: %d\n",
mlirTypeIsATransformOperationType(parsedType));
// CHECK: operation name equal: 1
MlirStringRef operationName =
mlirTransformOperationTypeGetOperationName(constructedType);
fprintf(stderr, "operation name equal: %d\n",
mlirStringRefEqual(operationName,
mlirStringRefCreateFromCString("foo.bar")));
// CHECK: !transform.op<"foo.bar">
mlirTypeDump(constructedType);
fprintf(stderr, "\n\n");
}
int main(void) {
MlirContext ctx = mlirContextCreate();
mlirDialectHandleRegisterDialect(mlirGetDialectHandle__transform__(), ctx);
testAnyOpType(ctx);
testOperationType(ctx);
return EXIT_SUCCESS;
}

View File

@ -70,9 +70,10 @@ set(MLIR_TEST_DEPENDS
mlir-capi-ir-test
mlir-capi-llvm-test
mlir-capi-pass-test
mlir-capi-sparse-tensor-test
mlir-capi-quant-test
mlir-capi-pdl-test
mlir-capi-quant-test
mlir-capi-sparse-tensor-test
mlir-capi-transform-test
mlir-linalg-ods-yaml-gen
mlir-lsp-server
mlir-pdll-lsp-server

View File

@ -64,4 +64,6 @@ transform.sequence failures(propagate) {
^bb0(%arg0: !pdl.operation):
// CHECK: cast %{{.*}} : !pdl.operation to !transform.any_op
%0 = cast %arg0: !pdl.operation to !transform.any_op
// CHECK: cast %{{.*}} : !transform.any_op to !transform.op<"builtin.module">
%1 = cast %0: !transform.any_op to !transform.op<"builtin.module">
}

View File

@ -841,3 +841,45 @@ transform.with_pdl_patterns {
transform.cast %2 : !transform.test_dialect_op to !pdl.operation
}
}
// -----
"test.some_op"() : () -> ()
"other_dialect.other_op"() : () -> ()
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @some : benefit(1) {
%0 = pdl.operation "test.some_op"
pdl.rewrite %0 with "transform.dialect"
}
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
%2 = transform.cast %0 : !pdl.operation to !transform.op<"test.some_op">
transform.cast %2 : !transform.op<"test.some_op"> to !pdl.operation
}
}
// -----
"test.some_op"() : () -> ()
// expected-note @below {{payload operation}}
"other_dialect.other_op"() : () -> ()
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @other : benefit(1) {
%0 = pdl.operation "other_dialect.other_op"
pdl.rewrite %0 with "transform.dialect"
}
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @other in %arg1 : (!pdl.operation) -> !pdl.operation
// expected-error @below {{incompatible payload operation name}}
%2 = transform.cast %0 : !pdl.operation to !transform.op<"test.some_op">
transform.cast %2 : !transform.op<"test.some_op"> to !pdl.operation
}
}

View File

@ -60,9 +60,10 @@ tools = [
'mlir-capi-ir-test',
'mlir-capi-llvm-test',
'mlir-capi-pass-test',
'mlir-capi-sparse-tensor-test',
'mlir-capi-quant-test',
'mlir-capi-pdl-test',
'mlir-capi-quant-test',
'mlir-capi-sparse-tensor-test',
'mlir-capi-transform-test',
'mlir-cpu-runner',
'mlir-linalg-ods-yaml-gen',
'mlir-reduce',

View File

@ -15,6 +15,20 @@ def run(f):
return f
@run
def testTypes():
# CHECK-LABEL: TEST: testTypes
# CHECK: !transform.any_op
any_op = transform.AnyOpType.get()
print(any_op)
# CHECK: !transform.op<"foo.bar">
# CHECK: foo.bar
concrete_op = transform.OperationType.get("foo.bar")
print(concrete_op)
print(concrete_op.operation_name)
@run
def testSequenceOp():
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,

View File

@ -565,6 +565,23 @@ mlir_c_api_cc_library(
],
)
mlir_c_api_cc_library(
name = "CAPITransformDialect",
srcs = [
"lib/CAPI/Dialect/Transform.cpp",
],
hdrs = [
"include/mlir-c/Dialect/Transform.h",
],
capi_deps = [
":CAPIIR",
],
includes = ["include"],
deps = [
":TransformDialect",
],
)
mlir_c_api_cc_library(
name = "CAPIMLProgram",
srcs = [