[mlir] add OperationType to the Transform dialect
Add a new OperationType handle type to the Transform dialect. This transform type is parameterized by the name of the payload operation it can point to. It is intended as a constraint on transformations that are only applicable to a specific kind of payload operations. If a transformation is applicable to a small set of operation classes, it can be wrapped into a transform op by using a disjunctive constraint, such as `Type<Or<[Transform_ConcreteOperation<"foo">.predicate, Transform_ConcreteOperation<"bar">.predicate]>>` for its operand without modifying this type. Broader sets of accepted operations should be modeled as specific types. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D135586
This commit is contained in:
parent
6bb997c032
commit
3e1f6d02f7
|
@ -0,0 +1,46 @@
|
|||
//===-- mlir-c/Dialect/Transform.h - C API for Transform Dialect --*- C -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
// Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_C_DIALECT_TRANSFORM_H
|
||||
#define MLIR_C_DIALECT_TRANSFORM_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// AnyOpType
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// OperationType
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirType
|
||||
mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirStringRef
|
||||
mlirTransformOperationTypeGetOperationName(MlirType type);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MLIR_C_DIALECT_TRANSFORM_H
|
|
@ -359,6 +359,8 @@ def Transform_Dialect : Dialect {
|
|||
/// mnemonic.
|
||||
[[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic);
|
||||
|
||||
void initializeTypes();
|
||||
|
||||
template <typename, typename...>
|
||||
friend class TransformDialectExtension;
|
||||
|
||||
|
|
|
@ -104,6 +104,13 @@ def TransformTypeInterface : TypeInterface<"TransformTypeInterface"> {
|
|||
"::mlir::ArrayRef<::mlir::Operation *>":$payload)
|
||||
>
|
||||
];
|
||||
|
||||
let extraSharedClassDeclaration = [{
|
||||
DiagnosedSilenceableFailure emitSilenceableError(Location loc) const {
|
||||
Diagnostic diag(loc, DiagnosticSeverity::Error);
|
||||
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def FunctionalStyleTransformOpTrait
|
||||
|
|
|
@ -23,4 +23,24 @@ def Transform_AnyOpType : TypeDef<Transform_Dialect, "AnyOp",
|
|||
let assemblyFormat = "";
|
||||
}
|
||||
|
||||
def Transform_OperationType : TypeDef<Transform_Dialect, "Operation",
|
||||
[DeclareTypeInterfaceMethods<TransformTypeInterface>]> {
|
||||
let description = [{
|
||||
Transform IR handle that can be associated with a list of Payload IR
|
||||
operations with the specified operation name.
|
||||
}];
|
||||
let mnemonic = "op";
|
||||
let parameters = (ins
|
||||
StringRefParameter<"Name of the allowed payload operation">:$operation_name
|
||||
);
|
||||
let assemblyFormat = "`<` $operation_name `>`";
|
||||
}
|
||||
|
||||
class Transform_ConcreteOpType<string opname>
|
||||
: Type<And<[Transform_OperationType.predicate,
|
||||
CPred<"$_self.cast<::mlir::transform::OperationType>()"
|
||||
".getOperationName() == \"" # opname # "\"">]>,
|
||||
"Transform IR handle to " # opname # " operations",
|
||||
"::mlir::transform::OperationType">;
|
||||
|
||||
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir-c/Dialect/Transform.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace mlir;
|
||||
using namespace mlir::python;
|
||||
using namespace mlir::python::adaptors;
|
||||
|
||||
void populateDialectTransformSubmodule(const pybind11::module &m) {
|
||||
//===-------------------------------------------------------------------===//
|
||||
// AnyOpType
|
||||
//===-------------------------------------------------------------------===//
|
||||
|
||||
auto anyOpType =
|
||||
mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType);
|
||||
anyOpType.def_classmethod(
|
||||
"get",
|
||||
[](py::object cls, MlirContext ctx) {
|
||||
return cls(mlirTransformAnyOpTypeGet(ctx));
|
||||
},
|
||||
"Get an instance of AnyOpType in the given context.", py::arg("cls"),
|
||||
py::arg("context") = py::none());
|
||||
|
||||
//===-------------------------------------------------------------------===//
|
||||
// OperationType
|
||||
//===-------------------------------------------------------------------===//
|
||||
|
||||
auto operationType =
|
||||
mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType);
|
||||
operationType.def_classmethod(
|
||||
"get",
|
||||
[](py::object cls, const std::string &operationName, MlirContext ctx) {
|
||||
MlirStringRef cOperationName =
|
||||
mlirStringRefCreate(operationName.data(), operationName.size());
|
||||
return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
|
||||
},
|
||||
"Get an instance of OperationType for the given kind in the given "
|
||||
"context",
|
||||
py::arg("cls"), py::arg("operation_name"),
|
||||
py::arg("context") = py::none());
|
||||
operationType.def_property_readonly(
|
||||
"operation_name",
|
||||
[](MlirType type) {
|
||||
MlirStringRef operationName =
|
||||
mlirTransformOperationTypeGetOperationName(type);
|
||||
return py::str(operationName.data, operationName.length);
|
||||
},
|
||||
"Get the name of the payload operation accepted by the handle.");
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(_mlirDialectsTransform, m) {
|
||||
m.doc() = "MLIR Transform dialect.";
|
||||
populateDialectTransformSubmodule(m);
|
||||
}
|
|
@ -116,6 +116,15 @@ add_mlir_upstream_c_api_library(MLIRCAPITensor
|
|||
MLIRTensorDialect
|
||||
)
|
||||
|
||||
add_mlir_upstream_c_api_library(MLIRCAPITransformDialect
|
||||
Transform.cpp
|
||||
|
||||
PARTIAL_SOURCES_INTENDED
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRCAPIIR
|
||||
MLIRTransformDialect
|
||||
)
|
||||
|
||||
add_mlir_upstream_c_api_library(MLIRCAPIQuant
|
||||
Quant.cpp
|
||||
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
//===- Transform.cpp - C Interface for Transform dialect ------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir-c/Dialect/Transform.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
||||
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform,
|
||||
transform::TransformDialect)
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// AnyOpType
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
bool mlirTypeIsATransformAnyOpType(MlirType type) {
|
||||
return unwrap(type).isa<transform::AnyOpType>();
|
||||
}
|
||||
|
||||
MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
|
||||
return wrap(transform::AnyOpType::get(unwrap(ctx)));
|
||||
}
|
||||
|
||||
//===---------------------------------------------------------------------===//
|
||||
// OperationType
|
||||
//===---------------------------------------------------------------------===//
|
||||
|
||||
bool mlirTypeIsATransformOperationType(MlirType type) {
|
||||
return unwrap(type).isa<transform::OperationType>();
|
||||
}
|
||||
|
||||
MlirType mlirTransformOperationTypeGet(MlirContext ctx,
|
||||
MlirStringRef operationName) {
|
||||
return wrap(
|
||||
transform::OperationType::get(unwrap(ctx), unwrap(operationName)));
|
||||
}
|
||||
|
||||
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
|
||||
return wrap(unwrap(type).cast<transform::OperationType>().getOperationName());
|
||||
}
|
|
@ -60,10 +60,7 @@ void transform::TransformDialect::initialize() {
|
|||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
|
||||
>();
|
||||
addTypesChecked<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
|
||||
>();
|
||||
initializeTypes();
|
||||
|
||||
pdl::OperationType::attachInterface<
|
||||
PDLOperationTypeTransformTypeInterfaceImpl>(*getContext());
|
||||
|
|
|
@ -30,8 +30,31 @@ generatedTypePrinter(Type def, AsmPrinter &printer);
|
|||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
|
||||
|
||||
void transform::TransformDialect::initializeTypes() {
|
||||
addTypesChecked<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "mlir/Dialect/Transform/IR/TransformTypes.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::AnyOpType::checkPayload(Location loc,
|
||||
ArrayRef<Operation *> payload) const {
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::OperationType::checkPayload(Location loc,
|
||||
ArrayRef<Operation *> payload) const {
|
||||
OperationName opName(getOperationName(), loc.getContext());
|
||||
for (Operation *op : payload) {
|
||||
if (opName != op->getName()) {
|
||||
DiagnosedSilenceableFailure diag =
|
||||
emitSilenceableError(loc) << "incompatible payload operation name";
|
||||
diag.attachNote(op->getLoc()) << "payload operation";
|
||||
return diag;
|
||||
}
|
||||
}
|
||||
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
|
|
@ -121,6 +121,7 @@ declare_mlir_dialect_python_bindings(
|
|||
SOURCES
|
||||
dialects/_transform_ops_ext.py
|
||||
dialects/transform/__init__.py
|
||||
_mlir_libs/_mlir/dialects/transform/__init__.pyi
|
||||
DIALECT_NAME transform)
|
||||
|
||||
declare_mlir_dialect_extension_python_bindings(
|
||||
|
@ -353,6 +354,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
|
|||
MLIRCAPISparseTensor
|
||||
)
|
||||
|
||||
declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
|
||||
MODULE_NAME _mlirDialectsTransform
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects.transform
|
||||
ROOT_DIR "${PYTHON_SOURCE_DIR}"
|
||||
SOURCES
|
||||
DialectTransform.cpp
|
||||
PRIVATE_LINK_LIBS
|
||||
LLVMSupport
|
||||
EMBED_CAPI_LINK_LIBS
|
||||
MLIRCAPIIR
|
||||
MLIRCAPITransformDialect
|
||||
)
|
||||
|
||||
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
|
||||
MODULE_NAME _mlirAsyncPasses
|
||||
ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from mlir.ir import Type, Context
|
||||
|
||||
|
||||
class AnyOpType(Type):
|
||||
@staticmethod
|
||||
def isinstance(type: Type) -> bool: ...
|
||||
|
||||
@staticmethod
|
||||
def get(context: Optional[Context] = None) -> AnyOpType: ...
|
||||
|
||||
|
||||
class OperationType(Type):
|
||||
@staticmethod
|
||||
def isinstance(type: Type) -> bool: ...
|
||||
|
||||
@staticmethod
|
||||
def get(operation_name: str, context: Optional[Context] = None) -> OperationType: ...
|
||||
|
||||
@property
|
||||
def operation_name(self) -> str: ...
|
|
@ -18,6 +18,16 @@ def _get_symbol_ref_attr(value: Union[Attribute, str]):
|
|||
return FlatSymbolRefAttr.get(value)
|
||||
|
||||
|
||||
class CastOp:
|
||||
|
||||
def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
super().__init__(
|
||||
result_type,
|
||||
_get_op_result_or_value(target),
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
|
||||
|
||||
class GetClosestIsolatedParentOp:
|
||||
|
||||
def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None):
|
||||
|
|
|
@ -18,3 +18,4 @@ class FailurePropagationMode(Enum):
|
|||
return 2
|
||||
|
||||
from .._transform_ops_gen import *
|
||||
from ..._mlir_libs._mlirDialectsTransform import *
|
||||
|
|
|
@ -54,6 +54,14 @@ _add_capi_test_executable(mlir-capi-pass-test
|
|||
MLIRCAPITransforms
|
||||
)
|
||||
|
||||
_add_capi_test_executable(mlir-capi-pdl-test
|
||||
pdl.c
|
||||
LINK_LIBS PRIVATE
|
||||
MLIRCAPIIR
|
||||
MLIRCAPIRegisterEverything
|
||||
MLIRCAPIPDL
|
||||
)
|
||||
|
||||
_add_capi_test_executable(mlir-capi-sparse-tensor-test
|
||||
sparse_tensor.c
|
||||
LINK_LIBS PRIVATE
|
||||
|
@ -70,10 +78,10 @@ _add_capi_test_executable(mlir-capi-quant-test
|
|||
MLIRCAPIQuant
|
||||
)
|
||||
|
||||
_add_capi_test_executable(mlir-capi-pdl-test
|
||||
pdl.c
|
||||
_add_capi_test_executable(mlir-capi-transform-test
|
||||
transform.c
|
||||
LINK_LIBS PRIVATE
|
||||
MLIRCAPIIR
|
||||
MLIRCAPIRegisterEverything
|
||||
MLIRCAPIPDL
|
||||
MLIRCAPITransformDialect
|
||||
)
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
//===- transform.c - Test of Transform dialect C API ----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
// Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// RUN: mlir-capi-transform-test 2>&1 | FileCheck %s
|
||||
|
||||
#include "mlir-c/Dialect/Transform.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// CHECK-LABEL: testAnyOpType
|
||||
void testAnyOpType(MlirContext ctx) {
|
||||
fprintf(stderr, "testAnyOpType\n");
|
||||
|
||||
MlirType parsedType = mlirTypeParseGet(
|
||||
ctx, mlirStringRefCreateFromCString("!transform.any_op"));
|
||||
MlirType constructedType = mlirTransformAnyOpTypeGet(ctx);
|
||||
|
||||
assert(!mlirTypeIsNull(parsedType) && "couldn't parse AnyOpType");
|
||||
assert(!mlirTypeIsNull(constructedType) && "couldn't construct AnyOpType");
|
||||
|
||||
// CHECK: equal: 1
|
||||
fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType));
|
||||
|
||||
// CHECK: parsedType isa AnyOpType: 1
|
||||
fprintf(stderr, "parsedType isa AnyOpType: %d\n",
|
||||
mlirTypeIsATransformAnyOpType(parsedType));
|
||||
// CHECK: parsedType isa OperationType: 0
|
||||
fprintf(stderr, "parsedType isa OperationType: %d\n",
|
||||
mlirTypeIsATransformOperationType(parsedType));
|
||||
|
||||
// CHECK: !transform.any_op
|
||||
mlirTypeDump(constructedType);
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testOperationType
|
||||
void testOperationType(MlirContext ctx) {
|
||||
fprintf(stderr, "testOperationType\n");
|
||||
|
||||
MlirType parsedType = mlirTypeParseGet(
|
||||
ctx, mlirStringRefCreateFromCString("!transform.op<\"foo.bar\">"));
|
||||
MlirType constructedType = mlirTransformOperationTypeGet(
|
||||
ctx, mlirStringRefCreateFromCString("foo.bar"));
|
||||
|
||||
assert(!mlirTypeIsNull(parsedType) && "couldn't parse AnyOpType");
|
||||
assert(!mlirTypeIsNull(constructedType) && "couldn't construct AnyOpType");
|
||||
|
||||
// CHECK: equal: 1
|
||||
fprintf(stderr, "equal: %d\n", mlirTypeEqual(parsedType, constructedType));
|
||||
|
||||
// CHECK: parsedType isa AnyOpType: 0
|
||||
fprintf(stderr, "parsedType isa AnyOpType: %d\n",
|
||||
mlirTypeIsATransformAnyOpType(parsedType));
|
||||
// CHECK: parsedType isa OperationType: 1
|
||||
fprintf(stderr, "parsedType isa OperationType: %d\n",
|
||||
mlirTypeIsATransformOperationType(parsedType));
|
||||
|
||||
// CHECK: operation name equal: 1
|
||||
MlirStringRef operationName =
|
||||
mlirTransformOperationTypeGetOperationName(constructedType);
|
||||
fprintf(stderr, "operation name equal: %d\n",
|
||||
mlirStringRefEqual(operationName,
|
||||
mlirStringRefCreateFromCString("foo.bar")));
|
||||
|
||||
// CHECK: !transform.op<"foo.bar">
|
||||
mlirTypeDump(constructedType);
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
mlirDialectHandleRegisterDialect(mlirGetDialectHandle__transform__(), ctx);
|
||||
testAnyOpType(ctx);
|
||||
testOperationType(ctx);
|
||||
return EXIT_SUCCESS;
|
||||
}
|
|
@ -70,9 +70,10 @@ set(MLIR_TEST_DEPENDS
|
|||
mlir-capi-ir-test
|
||||
mlir-capi-llvm-test
|
||||
mlir-capi-pass-test
|
||||
mlir-capi-sparse-tensor-test
|
||||
mlir-capi-quant-test
|
||||
mlir-capi-pdl-test
|
||||
mlir-capi-quant-test
|
||||
mlir-capi-sparse-tensor-test
|
||||
mlir-capi-transform-test
|
||||
mlir-linalg-ods-yaml-gen
|
||||
mlir-lsp-server
|
||||
mlir-pdll-lsp-server
|
||||
|
|
|
@ -64,4 +64,6 @@ transform.sequence failures(propagate) {
|
|||
^bb0(%arg0: !pdl.operation):
|
||||
// CHECK: cast %{{.*}} : !pdl.operation to !transform.any_op
|
||||
%0 = cast %arg0: !pdl.operation to !transform.any_op
|
||||
// CHECK: cast %{{.*}} : !transform.any_op to !transform.op<"builtin.module">
|
||||
%1 = cast %0: !transform.any_op to !transform.op<"builtin.module">
|
||||
}
|
||||
|
|
|
@ -841,3 +841,45 @@ transform.with_pdl_patterns {
|
|||
transform.cast %2 : !transform.test_dialect_op to !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
"test.some_op"() : () -> ()
|
||||
"other_dialect.other_op"() : () -> ()
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @some : benefit(1) {
|
||||
%0 = pdl.operation "test.some_op"
|
||||
pdl.rewrite %0 with "transform.dialect"
|
||||
}
|
||||
|
||||
sequence %arg0 : !pdl.operation failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
|
||||
%2 = transform.cast %0 : !pdl.operation to !transform.op<"test.some_op">
|
||||
transform.cast %2 : !transform.op<"test.some_op"> to !pdl.operation
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
"test.some_op"() : () -> ()
|
||||
// expected-note @below {{payload operation}}
|
||||
"other_dialect.other_op"() : () -> ()
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @other : benefit(1) {
|
||||
%0 = pdl.operation "other_dialect.other_op"
|
||||
pdl.rewrite %0 with "transform.dialect"
|
||||
}
|
||||
|
||||
sequence %arg0 : !pdl.operation failures(propagate) {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @other in %arg1 : (!pdl.operation) -> !pdl.operation
|
||||
// expected-error @below {{incompatible payload operation name}}
|
||||
%2 = transform.cast %0 : !pdl.operation to !transform.op<"test.some_op">
|
||||
transform.cast %2 : !transform.op<"test.some_op"> to !pdl.operation
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,9 +60,10 @@ tools = [
|
|||
'mlir-capi-ir-test',
|
||||
'mlir-capi-llvm-test',
|
||||
'mlir-capi-pass-test',
|
||||
'mlir-capi-sparse-tensor-test',
|
||||
'mlir-capi-quant-test',
|
||||
'mlir-capi-pdl-test',
|
||||
'mlir-capi-quant-test',
|
||||
'mlir-capi-sparse-tensor-test',
|
||||
'mlir-capi-transform-test',
|
||||
'mlir-cpu-runner',
|
||||
'mlir-linalg-ods-yaml-gen',
|
||||
'mlir-reduce',
|
||||
|
|
|
@ -15,6 +15,20 @@ def run(f):
|
|||
return f
|
||||
|
||||
|
||||
@run
|
||||
def testTypes():
|
||||
# CHECK-LABEL: TEST: testTypes
|
||||
# CHECK: !transform.any_op
|
||||
any_op = transform.AnyOpType.get()
|
||||
print(any_op)
|
||||
|
||||
# CHECK: !transform.op<"foo.bar">
|
||||
# CHECK: foo.bar
|
||||
concrete_op = transform.OperationType.get("foo.bar")
|
||||
print(concrete_op)
|
||||
print(concrete_op.operation_name)
|
||||
|
||||
|
||||
@run
|
||||
def testSequenceOp():
|
||||
sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE,
|
||||
|
|
|
@ -565,6 +565,23 @@ mlir_c_api_cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
mlir_c_api_cc_library(
|
||||
name = "CAPITransformDialect",
|
||||
srcs = [
|
||||
"lib/CAPI/Dialect/Transform.cpp",
|
||||
],
|
||||
hdrs = [
|
||||
"include/mlir-c/Dialect/Transform.h",
|
||||
],
|
||||
capi_deps = [
|
||||
":CAPIIR",
|
||||
],
|
||||
includes = ["include"],
|
||||
deps = [
|
||||
":TransformDialect",
|
||||
],
|
||||
)
|
||||
|
||||
mlir_c_api_cc_library(
|
||||
name = "CAPIMLProgram",
|
||||
srcs = [
|
||||
|
|
Loading…
Reference in New Issue