From fd6b404183ce1faa47a20a345d1f7d3486070f4f Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sat, 6 Nov 2021 07:14:17 +0000 Subject: [PATCH] Emit the boilerplate for Attribute printer/parser dialect dispatching from ODS Add a new `useDefaultAttributePrinterParser` boolean settings on the dialect (default to false for now) that emits the boilerplate to dispatch attribute parsing/printing to the auto-generated method. We will likely turn this on by default in the future. Differential Revision: https://reviews.llvm.org/D113329 --- mlir/include/mlir/IR/OpBase.td | 4 +++ mlir/include/mlir/TableGen/Dialect.h | 4 +++ mlir/lib/TableGen/Dialect.cpp | 4 +++ mlir/test/lib/Dialect/Test/TestAttributes.cpp | 21 ------------ mlir/test/lib/Dialect/Test/TestOps.td | 1 + mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 34 +++++++++++++++++++ mlir/tools/mlir-tblgen/DialectGen.cpp | 2 +- 7 files changed, 48 insertions(+), 22 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 7ba93ad2476d..70a5f2942a2d 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -298,6 +298,10 @@ class Dialect { // If this dialect overrides the hook for op interface fallback. bit hasOperationInterfaceFallback = 0; + // If this dialect should use default generated attribute parser boilerplate: + // it'll dispatch the parsing to every individual attributes directly. + bit useDefaultAttributePrinterParser = 0; + // If this dialect overrides the hook for canonicalization patterns. bit hasCanonicalizer = 0; diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h index 3030d6556b5b..7eb70030785b 100644 --- a/mlir/include/mlir/TableGen/Dialect.h +++ b/mlir/include/mlir/TableGen/Dialect.h @@ -74,6 +74,10 @@ public: /// Returns true if this dialect has fallback interfaces for its operations. bool hasOperationInterfaceFallback() const; + /// Returns true if this dialect should generate the default dispatch for + /// attribute printing/parsing. + bool useDefaultAttributePrinterParser() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record. bool operator==(const Dialect &other) const; diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp index ed775d236f11..bfaf7163f456 100644 --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -90,6 +90,10 @@ bool Dialect::hasOperationInterfaceFallback() const { return def->getValueAsBit("hasOperationInterfaceFallback"); } +bool Dialect::useDefaultAttributePrinterParser() const { + return def->getValueAsBit("useDefaultAttributePrinterParser"); +} + Dialect::EmitPrefix Dialect::getEmitAccessorPrefix() const { int prefix = def->getValueAsInt("emitAccessorPrefix"); if (prefix < 0 || prefix > static_cast(EmitPrefix::Both)) diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index e0c8ebbb7b93..d2e3f66c93ab 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -229,24 +229,3 @@ void TestDialect::registerAttributes() { #include "TestAttrDefs.cpp.inc" >(); } - -Attribute TestDialect::parseAttribute(DialectAsmParser &parser, - Type type) const { - StringRef attrTag; - if (failed(parser.parseKeyword(&attrTag))) - return Attribute(); - { - Attribute attr; - auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); - if (parseResult.hasValue()) - return attr; - } - parser.emitError(parser.getNameLoc(), "unknown test attribute"); - return Attribute(); -} - -void TestDialect::printAttribute(Attribute attr, - DialectAsmPrinter &printer) const { - if (succeeded(generatedAttributePrinter(attr, printer))) - return; -} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 90eb1d833c76..c5fcc5c55b64 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -34,6 +34,7 @@ def Test_Dialect : Dialect { let hasRegionResultAttrVerify = 1; let hasOperationInterfaceFallback = 1; let hasNonDefaultDestructor = 1; + let useDefaultAttributePrinterParser = 1; let dependentDialects = ["::mlir::DLTIDialect"]; let extraClassDeclaration = [{ diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 5e4cb4d73a39..f9e86a0eedcf 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -500,6 +500,34 @@ static ::mlir::OptionalParseResult generated{0}Parser( ::mlir::{0} &value) {{ )"; +/// The code block for default attribute parser/printer dispatch boilerplate. +/// {0}: the dialect fully qualified class name. +static const char *const dialectDefaultAttrPrinterParserDispatch = R"( +/// Parse an attribute registered to this dialect. +::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser, + ::mlir::Type type) const {{ + ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); + ::llvm::StringRef attrTag; + if (failed(parser.parseKeyword(&attrTag))) + return {{}; + {{ + ::mlir::Attribute attr; + auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); + if (parseResult.hasValue()) + return attr; + } + parser.emitError(typeLoc) << "unknown attribute `" + << attrTag << "` in dialect `" << getNamespace() << "`"; + return {{}; +} +/// Print an attribute registered to this dialect. +void {0}::printAttribute(::mlir::Attribute attr, + ::mlir::DialectAsmPrinter &printer) const {{ + if (succeeded(generatedAttributePrinter(attr, printer))) + return; +} +)"; + /// The code block used to start the auto-generated printer function. /// /// {0}: The name of the base value type, e.g. Attribute or Type. @@ -986,6 +1014,12 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) { << "::" << def.getCppClassName() << ")\n"; } + // Emit the default parser/printer for Attributes if the dialect asked for it. + if (valueType == "Attribute" && + defs.front().getDialect().useDefaultAttributePrinterParser()) + os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, + defs.front().getDialect().getCppClassName()); + return false; } diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp index 2e5b98380538..7767b257312e 100644 --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -208,7 +208,7 @@ static void emitDialectDecl(Dialect &dialect, // Check for any attributes/types registered to this dialect. If there are, // add the hooks for parsing/printing. - if (!dialectAttrs.empty()) + if (!dialectAttrs.empty() || dialect.useDefaultAttributePrinterParser()) os << attrParserDecl; if (!dialectTypes.empty()) os << typeParserDecl;