[mlir][OpDSL] Add support for adding canonicalization patterns.

Extend OpDSL with a `defines` method that can set the `hasCanonicalizer` flag for an OpDSL operation. If the flag is set via `defines(Canonicalizer)`, the operation needs to implement the `getCanonicalizationPatterns` method. The revision specifies the flag for linalg.fill_tensor and adds an empty `FillTensorOp::getCanonicalizationPatterns` implementation.

This revision is a preparation step to replace linalg.fill by its OpDSL counterpart linalg.fill_tensor. The two are only functionally equivalent if both specify the same canonicalization patterns. The revision is thus a prerequisite for the linalg.fill replacement.

Depends On D120725

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D120726
This commit is contained in:
gysit 2022-03-08 15:56:40 +00:00
parent 8d7850705c
commit d629645fcd
9 changed files with 109 additions and 9 deletions

View File

@ -55,6 +55,7 @@ def matmul(A=TensorDef(T1, S.M, S.K),
them to the same data type as the accumulator/output. them to the same data type as the accumulator/output.
""" """
domain(D.m, D.n, D.k) domain(D.m, D.n, D.k)
defines(Canonicalizer)
implements(ContractionOpInterface) implements(ContractionOpInterface)
C[D.m, D.n] += TypeFn.cast_signed( C[D.m, D.n] += TypeFn.cast_signed(
U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n]) U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n])
@ -78,6 +79,9 @@ An explicit iteration domain dimension order can be declared for the op via
Special identifying op interfaces can be declared for the op via Special identifying op interfaces can be declared for the op via
`implements(interface1[, interface2...])`. `implements(interface1[, interface2...])`.
Extra method definitions can be declared for the op via
`defines(definition1[, definition2...])`.
## Parameters ## Parameters
Structured operations take two types of runtime parameters namely scalars and Structured operations take two types of runtime parameters namely scalars and

View File

@ -2877,6 +2877,8 @@ metadata: !LinalgOpMetadata
the value operand, promoting it to the same data type as the output. the value operand, promoting it to the same data type as the output.
implements: implements:
- LinalgFillOpInterface - LinalgFillOpInterface
defines:
- hasCanonicalizer
structured_op: !LinalgStructuredOpConfig structured_op: !LinalgStructuredOpConfig
args: args:
- !LinalgOperandDefConfig - !LinalgOperandDefConfig

View File

@ -509,6 +509,10 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
FoldInsertPadIntoFill>(context); FoldInsertPadIntoFill>(context);
} }
// Canonicalization patterns for the OpDSL-generated linalg.fill_tensor op.
// Intentionally empty for now: the op declares `hasCanonicalizer`, so this
// hook must exist even before any patterns are registered.
// TODO: Add the FillOp patterns when transitioning to the OpDSL FillOp.
void FillTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// GenericOps // GenericOps
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -689,6 +689,16 @@ ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
FillOpInterface = OpInterfaceDef("LinalgFillOpInterface") FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
class OpDefinitionDef:
  """An extra method definition that an op declares via `defines`.

  Wraps the name of a boolean ODS flag (e.g. `hasCanonicalizer`); when the
  definition is attached to an op, the ODS generator emits
  `let <def_name> = 1;` on the generated op, and the op is then expected to
  provide the corresponding C++ implementation.
  """

  def __init__(self, def_name: str):
    # Name of the ODS flag to set (emitted as `let <def_name> = 1;`).
    self.def_name = def_name


# Requests canonicalization support: sets `hasCanonicalizer` on the generated
# op, which in turn requires a `getCanonicalizationPatterns` implementation.
Canonicalizer = OpDefinitionDef("hasCanonicalizer")
class OpMetadataDef(YAMLObject): class OpMetadataDef(YAMLObject):
"""Metadata about the op (generally not behavior impacting).""" """Metadata about the op (generally not behavior impacting)."""
yaml_tag = "!LinalgOpMetadata" yaml_tag = "!LinalgOpMetadata"
@ -699,6 +709,7 @@ class OpMetadataDef(YAMLObject):
self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
self.doc = doc self.doc = doc
self.implements = [] # type: List[OpInterfaceDef] self.implements = [] # type: List[OpInterfaceDef]
self.defines = [] # type: List[OpDefinitionsDef]
def to_yaml_custom_dict(self): def to_yaml_custom_dict(self):
d = dict( d = dict(
@ -708,6 +719,8 @@ class OpMetadataDef(YAMLObject):
) )
if self.implements: if self.implements:
d["implements"] = [intr.cpp_name for intr in self.implements] d["implements"] = [intr.cpp_name for intr in self.implements]
if self.defines:
d["defines"] = [defi.def_name for defi in self.defines]
return d return d

View File

@ -149,13 +149,21 @@ def linalg_structured_op(dsl_func=None,
return DefinedOpCallable(op_name, op_def) return DefinedOpCallable(op_name, op_def)
def domain(*dimensions: DimDef):
  """Records the iteration domain dimensions of the op under construction.

  Args:
    *dimensions: Iteration dimensions, each of type DimDef.

  Raises:
    ValueError: If any argument is not a DimDef.
  """
  if any(not isinstance(d, DimDef) for d in dimensions):
    raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
  current_op_def().domain.extend(dimensions)
def implements(*interfaces: OpInterfaceDef): def implements(*interfaces: OpInterfaceDef):
if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
raise ValueError(
f"Expected interfaces of type OpInterfaceDef but got {interfaces}")
current_op_def().metadata.implements.extend(interfaces) current_op_def().metadata.implements.extend(interfaces)
def domain(*dimensions: DimDef): def defines(*definitions: OpDefinitionDef):
if current_op_def().domain: if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
raise ValueError(f"Expected only one set of domain dimensions per operator") raise ValueError(
if any(not isinstance(dim, DimDef) for dim in dimensions): f"Expected definitions of type OpDefinitionDef but got {definitions}")
raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") current_op_def().metadata.defines.extend(definitions)
current_op_def().domain.extend(dimensions)

View File

@ -672,6 +672,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
the value operand, promoting it to the same data type as the output. the value operand, promoting it to the same data type as the output.
""" """
implements(FillOpInterface) implements(FillOpInterface)
defines(Canonicalizer)
O[None] = TypeFn.cast_signed(U, value) O[None] = TypeFn.cast_signed(U, value)

View File

@ -333,3 +333,58 @@ structured_op: !LinalgStructuredOpConfig
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0)) # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0)) # IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
# IMPL-NEXT: yields.push_back([[VAL1]]) # IMPL-NEXT: yields.push_back([[VAL1]])
# @linalg_structured_op
# def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)):
# """Title.
# Detailed description.
# """
# implements(FillOpInterface)
# defines(Canonicalizer)
# O[None] = TypeFn.cast(U, value)
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: test5
cpp_class_name: Test5Op
doc: |-
Title.
Detailed description.
implements:
- LinalgFillOpInterface
defines:
- hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: value
kind: scalar
type_var: T1
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: U
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: type
fn_name: cast
type_var: U
operands:
- !ScalarExpression
scalar_arg: value
# ODS-LABEL: def Test5Op : LinalgStructuredBase_Op<"test5"
# ODS-NEXT: /*extraInterfaces=*/[LinalgFillOpInterface])>
# ODS: let hasCanonicalizer = 1;

View File

@ -7,11 +7,14 @@ from mlir.dialects.linalg.opdsl.lang import *
# CHECK-LABEL: matmul # CHECK-LABEL: matmul
# CHECK: implements: # CHECK: implements:
# CHECK-NEXT: - LinalgContractionOpInterface # CHECK-NEXT: - LinalgContractionOpInterface
# CHECK: defines:
# CHECK-NEXT: - hasCanonicalizer
@linalg_structured_op @linalg_structured_op
def matmul( def matmul(
A=TensorDef(T, S.M, S.K), A=TensorDef(T, S.M, S.K),
B=TensorDef(T, S.K, S.N), B=TensorDef(T, S.K, S.N),
C=TensorDef(U, S.M, S.N, output=True)): C=TensorDef(U, S.M, S.N, output=True)):
implements(ContractionOpInterface) implements(ContractionOpInterface)
defines(Canonicalizer)
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed( C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
U, B[D.k, D.n]) U, B[D.k, D.n])

View File

@ -53,6 +53,7 @@ struct LinalgOpMetadata {
std::string cppClassName; std::string cppClassName;
Optional<std::string> doc; Optional<std::string> doc;
SmallVector<std::string> implements; SmallVector<std::string> implements;
SmallVector<std::string> defines;
}; };
struct SerializedAffineMap { struct SerializedAffineMap {
@ -233,6 +234,7 @@ struct MappingTraits<LinalgOpMetadata> {
io.mapRequired("cpp_class_name", info.cppClassName); io.mapRequired("cpp_class_name", info.cppClassName);
io.mapOptional("doc", info.doc); io.mapOptional("doc", info.doc);
io.mapOptional("implements", info.implements); io.mapOptional("implements", info.implements);
io.mapOptional("defines", info.defines);
} }
}; };
@ -499,7 +501,8 @@ static const char bannerFormat[] = R"FMT(
// {3}: documentation (summary + description) // {3}: documentation (summary + description)
// {4}: op attribute list // {4}: op attribute list
// {5}: builder methods taking standalone attribute parameters // {5}: builder methods taking standalone attribute parameters
// {6}: additional methods for attributes used by indexing maps // {6}: additional method defintions
// {7}: additional methods for attributes used by indexing maps
static const char structuredOpOdsHeaderFormat[] = R"FMT( static const char structuredOpOdsHeaderFormat[] = R"FMT(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Op definition for {0} // Op definition for {0}
@ -573,6 +576,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
]; ];
let hasCustomAssemblyFormat = 1; let hasCustomAssemblyFormat = 1;
let hasFolder = 1; let hasFolder = 1;
{6}
let extraClassDeclaration = structuredOpsBaseDecls # [{{ let extraClassDeclaration = structuredOpsBaseDecls # [{{
// Auto-generated. // Auto-generated.
@ -589,7 +593,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
// Generic methods. // Generic methods.
static unsigned getNumRegionArgs(); static unsigned getNumRegionArgs();
std::string getLibraryCallName(); std::string getLibraryCallName();
{6} {7}
}]; }];
} }
)FMT"; )FMT";
@ -736,6 +740,12 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
std::string definitionList;
for (const std::string &definition : opConfig.metadata->defines) {
static const char definitionFmt[] = "let {0} = 1;\n";
definitionList.append(llvm::formatv(definitionFmt, definition));
}
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
return isAttribute(arg.kind); return isAttribute(arg.kind);
})) { })) {
@ -794,7 +804,7 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
os << llvm::formatv(structuredOpOdsHeaderFormat, os << llvm::formatv(structuredOpOdsHeaderFormat,
opConfig.metadata->cppClassName, opConfig.metadata->name, opConfig.metadata->cppClassName, opConfig.metadata->name,
interfaceNameList, doc, attrList, attrBuilder, interfaceNameList, doc, attrList, attrBuilder,
attrMethods); definitionList, attrMethods);
return success(); return success();
} }