[mlir][OpDSL] Add support for adding canonicalization patterns.
Extend OpDSL with a `defines` method that can set the `hasCanonicalizer` flag for an OpDSL operation. If the flag is set via `defines(Canonicalizer)` the operation needs to implement the `getCanonicalizationPatterns` method. The revision specifies the flag for linalg.fill_tensor and adds an empty `FillTensorOp::getCanonicalizationPatterns` implementation. This revision is a preparation step to replace linalg.fill by its OpDSL counterpart linalg.fill_tensor. The two are only functionally equivalent if both specify the same canonicalization patterns. The revision is thus a prerequisite for the linalg.fill replacement. Depends On D120725 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D120726
This commit is contained in:
parent
8d7850705c
commit
d629645fcd
|
@ -55,6 +55,7 @@ def matmul(A=TensorDef(T1, S.M, S.K),
|
||||||
them to the same data type as the accumulator/output.
|
them to the same data type as the accumulator/output.
|
||||||
"""
|
"""
|
||||||
domain(D.m, D.n, D.k)
|
domain(D.m, D.n, D.k)
|
||||||
|
defines(Canonicalizer)
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
C[D.m, D.n] += TypeFn.cast_signed(
|
C[D.m, D.n] += TypeFn.cast_signed(
|
||||||
U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n])
|
U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n])
|
||||||
|
@ -78,6 +79,9 @@ An explicit iteration domain dimension order can be declared for the op via
|
||||||
Special identifying op interfaces can be declared for the op via
|
Special identifying op interfaces can be declared for the op via
|
||||||
`implements(interface1[, interface2...])`.
|
`implements(interface1[, interface2...])`.
|
||||||
|
|
||||||
|
Extra method definitions can be declared for the op via
|
||||||
|
`defines(definition1[, definition2...])`.
|
||||||
|
|
||||||
## Parameters
|
## Parameters
|
||||||
|
|
||||||
Structured operations take two types of runtime parameters namely scalars and
|
Structured operations take two types of runtime parameters namely scalars and
|
||||||
|
|
|
@ -2877,6 +2877,8 @@ metadata: !LinalgOpMetadata
|
||||||
the value operand, promoting it to the same data type as the output.
|
the value operand, promoting it to the same data type as the output.
|
||||||
implements:
|
implements:
|
||||||
- LinalgFillOpInterface
|
- LinalgFillOpInterface
|
||||||
|
defines:
|
||||||
|
- hasCanonicalizer
|
||||||
structured_op: !LinalgStructuredOpConfig
|
structured_op: !LinalgStructuredOpConfig
|
||||||
args:
|
args:
|
||||||
- !LinalgOperandDefConfig
|
- !LinalgOperandDefConfig
|
||||||
|
|
|
@ -509,6 +509,10 @@ void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||||
FoldInsertPadIntoFill>(context);
|
FoldInsertPadIntoFill>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Add the FillOp patterns when transitioning to the OpDSL FillOp.
|
||||||
|
void FillTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||||
|
MLIRContext *context) {}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// GenericOps
|
// GenericOps
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -689,6 +689,16 @@ ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface")
|
||||||
FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
|
FillOpInterface = OpInterfaceDef("LinalgFillOpInterface")
|
||||||
|
|
||||||
|
|
||||||
|
class OpDefinitionDef:
|
||||||
|
"""A method that an op implements."""
|
||||||
|
|
||||||
|
def __init__(self, def_name: str):
|
||||||
|
self.def_name = def_name
|
||||||
|
|
||||||
|
|
||||||
|
Canonicalizer = OpDefinitionDef("hasCanonicalizer")
|
||||||
|
|
||||||
|
|
||||||
class OpMetadataDef(YAMLObject):
|
class OpMetadataDef(YAMLObject):
|
||||||
"""Metadata about the op (generally not behavior impacting)."""
|
"""Metadata about the op (generally not behavior impacting)."""
|
||||||
yaml_tag = "!LinalgOpMetadata"
|
yaml_tag = "!LinalgOpMetadata"
|
||||||
|
@ -699,6 +709,7 @@ class OpMetadataDef(YAMLObject):
|
||||||
self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
|
self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
|
||||||
self.doc = doc
|
self.doc = doc
|
||||||
self.implements = [] # type: List[OpInterfaceDef]
|
self.implements = [] # type: List[OpInterfaceDef]
|
||||||
|
self.defines = [] # type: List[OpDefinitionsDef]
|
||||||
|
|
||||||
def to_yaml_custom_dict(self):
|
def to_yaml_custom_dict(self):
|
||||||
d = dict(
|
d = dict(
|
||||||
|
@ -708,6 +719,8 @@ class OpMetadataDef(YAMLObject):
|
||||||
)
|
)
|
||||||
if self.implements:
|
if self.implements:
|
||||||
d["implements"] = [intr.cpp_name for intr in self.implements]
|
d["implements"] = [intr.cpp_name for intr in self.implements]
|
||||||
|
if self.defines:
|
||||||
|
d["defines"] = [defi.def_name for defi in self.defines]
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -149,13 +149,21 @@ def linalg_structured_op(dsl_func=None,
|
||||||
return DefinedOpCallable(op_name, op_def)
|
return DefinedOpCallable(op_name, op_def)
|
||||||
|
|
||||||
|
|
||||||
|
def domain(*dimensions: DimDef):
|
||||||
|
if any(not isinstance(d, DimDef) for d in dimensions):
|
||||||
|
raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
|
||||||
|
current_op_def().domain.extend(dimensions)
|
||||||
|
|
||||||
|
|
||||||
def implements(*interfaces: OpInterfaceDef):
|
def implements(*interfaces: OpInterfaceDef):
|
||||||
|
if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected interfaces of type OpInterfaceDef but got {interfaces}")
|
||||||
current_op_def().metadata.implements.extend(interfaces)
|
current_op_def().metadata.implements.extend(interfaces)
|
||||||
|
|
||||||
|
|
||||||
def domain(*dimensions: DimDef):
|
def defines(*definitions: OpDefinitionDef):
|
||||||
if current_op_def().domain:
|
if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
|
||||||
raise ValueError(f"Expected only one set of domain dimensions per operator")
|
raise ValueError(
|
||||||
if any(not isinstance(dim, DimDef) for dim in dimensions):
|
f"Expected definitions of type OpDefinitionDef but got {definitions}")
|
||||||
raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
|
current_op_def().metadata.defines.extend(definitions)
|
||||||
current_op_def().domain.extend(dimensions)
|
|
||||||
|
|
|
@ -672,6 +672,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)):
|
||||||
the value operand, promoting it to the same data type as the output.
|
the value operand, promoting it to the same data type as the output.
|
||||||
"""
|
"""
|
||||||
implements(FillOpInterface)
|
implements(FillOpInterface)
|
||||||
|
defines(Canonicalizer)
|
||||||
O[None] = TypeFn.cast_signed(U, value)
|
O[None] = TypeFn.cast_signed(U, value)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -333,3 +333,58 @@ structured_op: !LinalgStructuredOpConfig
|
||||||
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
|
# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
|
||||||
# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
|
# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
|
||||||
# IMPL-NEXT: yields.push_back([[VAL1]])
|
# IMPL-NEXT: yields.push_back([[VAL1]])
|
||||||
|
|
||||||
|
# @linalg_structured_op
|
||||||
|
# def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)):
|
||||||
|
# """Title.
|
||||||
|
|
||||||
|
# Detailed description.
|
||||||
|
# """
|
||||||
|
# implements(FillOpInterface)
|
||||||
|
# defines(Canonicalizer)
|
||||||
|
# O[None] = TypeFn.cast(U, value)
|
||||||
|
|
||||||
|
--- !LinalgOpConfig
|
||||||
|
metadata: !LinalgOpMetadata
|
||||||
|
name: test5
|
||||||
|
cpp_class_name: Test5Op
|
||||||
|
doc: |-
|
||||||
|
Title.
|
||||||
|
|
||||||
|
Detailed description.
|
||||||
|
implements:
|
||||||
|
- LinalgFillOpInterface
|
||||||
|
defines:
|
||||||
|
- hasCanonicalizer
|
||||||
|
structured_op: !LinalgStructuredOpConfig
|
||||||
|
args:
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: value
|
||||||
|
kind: scalar
|
||||||
|
type_var: T1
|
||||||
|
- !LinalgOperandDefConfig
|
||||||
|
name: O
|
||||||
|
kind: output_tensor
|
||||||
|
type_var: U
|
||||||
|
shape_map: affine_map<() -> ()>
|
||||||
|
indexing_maps: !LinalgIndexingMapsConfig
|
||||||
|
static_indexing_maps:
|
||||||
|
- affine_map<() -> ()>
|
||||||
|
- affine_map<() -> ()>
|
||||||
|
iterator_types: []
|
||||||
|
assignments:
|
||||||
|
- !ScalarAssign
|
||||||
|
arg: O
|
||||||
|
value: !ScalarExpression
|
||||||
|
scalar_fn:
|
||||||
|
kind: type
|
||||||
|
fn_name: cast
|
||||||
|
type_var: U
|
||||||
|
operands:
|
||||||
|
- !ScalarExpression
|
||||||
|
scalar_arg: value
|
||||||
|
|
||||||
|
# ODS-LABEL: def Test5Op : LinalgStructuredBase_Op<"test5"
|
||||||
|
# ODS-NEXT: /*extraInterfaces=*/[LinalgFillOpInterface])>
|
||||||
|
|
||||||
|
# ODS: let hasCanonicalizer = 1;
|
||||||
|
|
|
@ -7,11 +7,14 @@ from mlir.dialects.linalg.opdsl.lang import *
|
||||||
# CHECK-LABEL: matmul
|
# CHECK-LABEL: matmul
|
||||||
# CHECK: implements:
|
# CHECK: implements:
|
||||||
# CHECK-NEXT: - LinalgContractionOpInterface
|
# CHECK-NEXT: - LinalgContractionOpInterface
|
||||||
|
# CHECK: defines:
|
||||||
|
# CHECK-NEXT: - hasCanonicalizer
|
||||||
@linalg_structured_op
|
@linalg_structured_op
|
||||||
def matmul(
|
def matmul(
|
||||||
A=TensorDef(T, S.M, S.K),
|
A=TensorDef(T, S.M, S.K),
|
||||||
B=TensorDef(T, S.K, S.N),
|
B=TensorDef(T, S.K, S.N),
|
||||||
C=TensorDef(U, S.M, S.N, output=True)):
|
C=TensorDef(U, S.M, S.N, output=True)):
|
||||||
implements(ContractionOpInterface)
|
implements(ContractionOpInterface)
|
||||||
|
defines(Canonicalizer)
|
||||||
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
|
C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed(
|
||||||
U, B[D.k, D.n])
|
U, B[D.k, D.n])
|
|
@ -53,6 +53,7 @@ struct LinalgOpMetadata {
|
||||||
std::string cppClassName;
|
std::string cppClassName;
|
||||||
Optional<std::string> doc;
|
Optional<std::string> doc;
|
||||||
SmallVector<std::string> implements;
|
SmallVector<std::string> implements;
|
||||||
|
SmallVector<std::string> defines;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SerializedAffineMap {
|
struct SerializedAffineMap {
|
||||||
|
@ -233,6 +234,7 @@ struct MappingTraits<LinalgOpMetadata> {
|
||||||
io.mapRequired("cpp_class_name", info.cppClassName);
|
io.mapRequired("cpp_class_name", info.cppClassName);
|
||||||
io.mapOptional("doc", info.doc);
|
io.mapOptional("doc", info.doc);
|
||||||
io.mapOptional("implements", info.implements);
|
io.mapOptional("implements", info.implements);
|
||||||
|
io.mapOptional("defines", info.defines);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -499,7 +501,8 @@ static const char bannerFormat[] = R"FMT(
|
||||||
// {3}: documentation (summary + description)
|
// {3}: documentation (summary + description)
|
||||||
// {4}: op attribute list
|
// {4}: op attribute list
|
||||||
// {5}: builder methods taking standalone attribute parameters
|
// {5}: builder methods taking standalone attribute parameters
|
||||||
// {6}: additional methods for attributes used by indexing maps
|
// {6}: additional method defintions
|
||||||
|
// {7}: additional methods for attributes used by indexing maps
|
||||||
static const char structuredOpOdsHeaderFormat[] = R"FMT(
|
static const char structuredOpOdsHeaderFormat[] = R"FMT(
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Op definition for {0}
|
// Op definition for {0}
|
||||||
|
@ -573,6 +576,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
|
||||||
];
|
];
|
||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
{6}
|
||||||
|
|
||||||
let extraClassDeclaration = structuredOpsBaseDecls # [{{
|
let extraClassDeclaration = structuredOpsBaseDecls # [{{
|
||||||
// Auto-generated.
|
// Auto-generated.
|
||||||
|
@ -589,7 +593,7 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
|
||||||
// Generic methods.
|
// Generic methods.
|
||||||
static unsigned getNumRegionArgs();
|
static unsigned getNumRegionArgs();
|
||||||
std::string getLibraryCallName();
|
std::string getLibraryCallName();
|
||||||
{6}
|
{7}
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
)FMT";
|
)FMT";
|
||||||
|
@ -736,6 +740,12 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
|
||||||
|
|
||||||
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
|
interfaceNameList = interleaveToString(opConfig.metadata->implements, ", ");
|
||||||
|
|
||||||
|
std::string definitionList;
|
||||||
|
for (const std::string &definition : opConfig.metadata->defines) {
|
||||||
|
static const char definitionFmt[] = "let {0} = 1;\n";
|
||||||
|
definitionList.append(llvm::formatv(definitionFmt, definition));
|
||||||
|
}
|
||||||
|
|
||||||
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) {
|
||||||
return isAttribute(arg.kind);
|
return isAttribute(arg.kind);
|
||||||
})) {
|
})) {
|
||||||
|
@ -794,7 +804,7 @@ static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
|
||||||
os << llvm::formatv(structuredOpOdsHeaderFormat,
|
os << llvm::formatv(structuredOpOdsHeaderFormat,
|
||||||
opConfig.metadata->cppClassName, opConfig.metadata->name,
|
opConfig.metadata->cppClassName, opConfig.metadata->name,
|
||||||
interfaceNameList, doc, attrList, attrBuilder,
|
interfaceNameList, doc, attrList, attrBuilder,
|
||||||
attrMethods);
|
definitionList, attrMethods);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue