From ada9aa5a228200cb71269c371308e82c42fd4abc Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 7 Jun 2021 18:33:29 +0200 Subject: [PATCH] [mlir] Make MemRef element type extensible Historically, MemRef only supported a restricted list of element types that were known to be storable in memory. This is unnecessarily restrictive given the open nature of MLIR's type system. Allow types to opt into being used as MemRef elements by implementing a type interface. For now, the interface is merely a declaration with no methods. Later, methods to query, e.g., the type size or whether a type can alias elements of another type may be added. Harden the "standard"-to-LLVM conversion against memrefs with non-builtin types. See https://llvm.discourse.group/t/rfc-memref-of-custom-types/3558. Depends On D103826 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D103827 --- mlir/docs/Dialects/Builtin.md | 4 +++ mlir/include/mlir/IR/BuiltinTypes.h | 9 ++++- mlir/include/mlir/IR/BuiltinTypes.td | 33 +++++++++++++++++++ mlir/include/mlir/IR/CMakeLists.txt | 3 ++ .../StandardToLLVM/StandardToLLVM.cpp | 4 +++ mlir/lib/IR/BuiltinTypes.cpp | 6 ++++ .../convert-static-memref-ops.mlir | 20 +++++++++++ .../Conversion/StandardToLLVM/invalid.mlir | 1 + mlir/test/IR/parser.mlir | 3 ++ mlir/test/lib/Dialect/Test/CMakeLists.txt | 4 +-- mlir/test/lib/Dialect/Test/TestTypeDefs.td | 6 ++++ 11 files changed, 90 insertions(+), 3 deletions(-) diff --git a/mlir/docs/Dialects/Builtin.md b/mlir/docs/Dialects/Builtin.md index c48fc1bede68..b39506a39b5a 100644 --- a/mlir/docs/Dialects/Builtin.md +++ b/mlir/docs/Dialects/Builtin.md @@ -30,3 +30,7 @@ Operations. ## Types [include "Dialects/BuiltinTypes.md"] + +## Type Interfaces + +[include "Dialects/BuiltinTypeInterfaces.md"] diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 718fffd3e7b6..d858c3129091 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -192,6 +192,12 @@ public: #define GET_TYPEDEF_CLASSES #include "mlir/IR/BuiltinTypes.h.inc" +//===----------------------------------------------------------------------===// +// Tablegen Interface Declarations +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinTypeInterfaces.h.inc" + namespace mlir { //===----------------------------------------------------------------------===// // MemRefType @@ -266,7 +272,8 @@ inline bool BaseMemRefType::classof(Type type) { } inline bool BaseMemRefType::isValidElementType(Type type) { - return type.isIntOrIndexOrFloat() || type.isa(); + return type.isIntOrIndexOrFloat() || type.isa() || + type.isa(); } inline bool FloatType::classof(Type type) { diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 349da5663f9d..85787afc4954 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -248,6 +248,31 @@ def Builtin_Integer : Builtin_Type<"Integer"> { }]; } +//===----------------------------------------------------------------------===// +// MemRefElementTypeInterface +//===----------------------------------------------------------------------===// + +def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + Indication that this type can be used as element in memref types. + + Implementing this interface establishes a contract between this type and the + memref type indicating that this type can be used as element of ranked or + unranked memrefs. The type is expected to: + + - model an entity stored in memory; + - have non-zero size. + + For example, scalar values such as integers can implement this interface, + but indicator types such as `void` or `unit` should not. + + The interface currently has no methods and is used by types to opt into + being memref elements. This may change in the future, in particular to + require types to provide their size or alignment given a data layout. + }]; +} + //===----------------------------------------------------------------------===// // MemRefType //===----------------------------------------------------------------------===// @@ -282,6 +307,14 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> { on the rank. Other uses of this type are disallowed or will have undefined behavior. + Are accepted as elements: + + - built-in integer types; + - built-in index type; + - built-in floating point types; + - built-in vector types with elements of the above types; + - any other type implementing `MemRefElementTypeInterface`. + ##### Codegen of Unranked Memref Using unranked memref in codegen besides the case mentioned above is highly diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt index 42e07811a4a5..b8b49aa425a9 100644 --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -24,6 +24,8 @@ add_public_tablegen_target(MLIRBuiltinOpsIncGen) set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td) mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls) mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs) +mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs) add_public_tablegen_target(MLIRBuiltinTypesIncGen) set(LLVM_TARGET_DEFINITIONS TensorEncoding.td) @@ -35,3 +37,4 @@ add_mlir_doc(BuiltinAttributes BuiltinAttributes Dialects/ -gen-attrdef-doc) add_mlir_doc(BuiltinLocationAttributes BuiltinLocationAttributes Dialects/ -gen-attrdef-doc) add_mlir_doc(BuiltinOps BuiltinOps Dialects/ -gen-op-doc) add_mlir_doc(BuiltinTypes BuiltinTypes Dialects/ -gen-typedef-doc) +add_mlir_doc(BuiltinTypes BuiltinTypeInterfaces Dialects/ -gen-type-interface-docs) diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index dcbb4b336213..11d0cd6fdc76 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -349,6 +349,8 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) { // unpack the `sizes` and `strides` arrays. SmallVector types = getMemRefDescriptorFields(type, /*unpackAggregates=*/false); + if (types.empty()) + return {}; return LLVM::LLVMStructType::getLiteral(&getContext(), types); } @@ -368,6 +370,8 @@ SmallVector LLVMTypeConverter::getUnrankedMemRefDescriptorFields() { } Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { + if (!convertType(type.getElementType())) + return {}; return LLVM::LLVMStructType::getLiteral(&getContext(), getUnrankedMemRefDescriptorFields()); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index baadd8d0433c..77d64080de6e 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -31,6 +31,12 @@ using namespace mlir::detail; #define GET_TYPEDEF_CLASSES #include "mlir/IR/BuiltinTypes.cpp.inc" +//===----------------------------------------------------------------------===// +/// Tablegen Interface Definitions +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc" + //===----------------------------------------------------------------------===// // BuiltinDialect //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir index 27623393148f..6df3c9494375 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -427,3 +427,23 @@ module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry> } { return } } + +// ----- + +// Should not convert memrefs with unsupported types in any convention. + +// CHECK: @unsupported_memref_element_type +// CHECK-SAME: memref< +// CHECK-NOT: !llvm.struct +// BAREPTR: @unsupported_memref_element_type +// BAREPTR-SAME: memref< +// BAREPTR-NOT: !llvm.ptr +func private @unsupported_memref_element_type() -> memref<42 x !test.memref_element> + +// CHECK: @unsupported_unranked_memref_element_type +// CHECK-SAME: memref< +// CHECK-NOT: !llvm.struct +// BAREPTR: @unsupported_unranked_memref_element_type +// BAREPTR-SAME: memref< +// BAREPTR-NOT: !llvm.ptr +func private @unsupported_unranked_memref_element_type() -> memref<* x !test.memref_element> diff --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir index 8dbc2bfddd80..5b6e7577cc77 100644 --- a/mlir/test/Conversion/StandardToLLVM/invalid.mlir +++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir @@ -6,3 +6,4 @@ func private @unsupported_signature() -> tensor<10 x i32> // ----- func private @partially_supported_signature() -> (vector<10 x i32>, tensor<10 x i32>) + diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 7e8810c7479d..2a3487cffe4c 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -178,6 +178,9 @@ func private @memref_with_complex_elems(memref<1x?xcomplex>) // CHECK: func private @memref_with_vector_elems(memref<1x?xvector<10xf32>>) func private @memref_with_vector_elems(memref<1x?xvector<10xf32>>) +// CHECK: func private @memref_with_custom_elem(memref<1x?x!test.memref_element>) +func private @memref_with_custom_elem(memref<1x?x!test.memref_element>) + // CHECK: func private @unranked_memref_with_complex_elems(memref<*xcomplex>) func private @unranked_memref_with_complex_elems(memref<*xcomplex>) diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index d1cf46ae5788..30fe52e15079 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -17,8 +17,8 @@ mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs) add_public_tablegen_target(MLIRTestAttrDefIncGen) set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td) -mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls) -mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs) +mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls -typedefs-dialect=test) +mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs -typedefs-dialect=test) add_public_tablegen_target(MLIRTestTypeDefIncGen) diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 9821774eeede..a5ae219780b4 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -15,6 +15,7 @@ // To get the test dialect def. include "TestOps.td" +include "mlir/IR/BuiltinTypes.td" include "mlir/Interfaces/DataLayoutInterfaces.td" // All of the types will extend this class. @@ -176,4 +177,9 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [ }]; } +def TestMemRefElementType : Test_Type<"TestMemRefElementType", + [MemRefElementTypeInterface]> { + let mnemonic = "memref_element"; +} + #endif // TEST_TYPEDEFS